diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/__init__.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/filter.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..bddbf3ab8d1bcbba804f9790ef0290d437bcde69
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/filter.py
@@ -0,0 +1,56 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Callable, Iterable, List, Union
+
+from dllm_eval.api.instance import Instance
+
+
+class Filter(ABC):
+    """
+    Filter classes operate on a per-task level.
+    They take all model outputs (`instance.resps` for all `task.instances`)
+    across all instances of a task, and perform operations.
+    In a single run, one can configure any number of separate filters or lists of filters.
+
+    """
+
+    def __init__(self, **kwargs) -> None:
+        """
+        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
+        """
+
+    @abstractmethod
+    def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
+        """
+        Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
+        Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
+        if passed [<resps for instance 0>, <resps for instance 1>], it should return
+        [<filtered resps for instance 0>, <filtered resps for instance 1>].
+        """
+        return resps
+
+
+@dataclass
+class FilterEnsemble:
+    """
+    FilterEnsemble creates a pipeline applying multiple filters.
+    Its intended usage is to stack multiple post-processing steps in order.
+    `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
+    pipeline separately.
+    """
+
+    name: str
+    filters: List[Callable[[], Filter]]
+
+    def apply(self, instances: List[Instance]) -> None:
+        resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
+        resps, docs = list(resps), list(docs)
+
+        for f in self.filters:
+            # apply filters in sequence
+            resps = f().apply(resps, docs)
+
+        # add the end results after filtering to `filtered_resps` of their respective source instances.
+        # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
+        for inst, resp in zip(instances, resps):
+            inst.filtered_resps[self.name] = resp
diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/group.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/group.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c60739bbd26c79ecab91f54240798b2ae9e3313
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/group.py
@@ -0,0 +1,115 @@
+import abc
+from dataclasses import asdict, dataclass
+from inspect import getsource
+from typing import Any, Callable, List, Optional, Union
+
+
+@dataclass
+class AggMetricConfig(dict):
+    metric: Optional[str] = None
+    aggregation: Optional[str] = "mean"
+    weight_by_size: Optional[str] = False
+    # list of filter names which should be incorporated into the aggregated metric.
+    filter_list: Optional[Union[str, list]] = "none"
+
+    def __post_init__(self):
+        if self.aggregation != "mean" and not callable(self.aggregation):
+            raise ValueError(
+                f"Currently, 'mean' is the only pre-defined aggregation across groups' subtasks. Got '{self.aggregation}'."
+            )
+
+        if isinstance(self.filter_list, str):
+            self.filter_list = [self.filter_list]
+
+
+@dataclass
+class GroupConfig(dict):
+    group: Optional[str] = None
+    group_alias: Optional[str] = None
+    task: Optional[Union[str, list]] = None
+    aggregate_metric_list: Optional[
+        Union[List[AggMetricConfig], AggMetricConfig, dict]
+    ] = None
+    metadata: Optional[dict] = (
+        None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
+    )
+
+    def __getitem__(self, item):
+        return getattr(self, item)
+
+    def __setitem__(self, item, value):
+        return setattr(self, item, value)
+
+    def __post_init__(self):
+        if self.aggregate_metric_list is not None:
+            if isinstance(self.aggregate_metric_list, dict):
+                self.aggregate_metric_list = [self.aggregate_metric_list]
+
+            self.aggregate_metric_list = [
+                AggMetricConfig(**item) if isinstance(item, dict) else item
+                for item in self.aggregate_metric_list
+            ]
+
+    def to_dict(self, keep_callable: bool = False) -> dict:
+        """Dumps the current config as a dictionary object, in a printable format.
+        Callable fields are serialized to their source (or string repr), so the
+        config can be dumped alongside results and the full task configuration.
+
+        :return: dict
+            A printable dictionary version of the GroupConfig object.
+
+        # TODO: should any default value in the GroupConfig not be printed?
+        """
+        cfg_dict = asdict(self)
+        # serialize callable values so that the config can be printed/dumped
+        for k, v in list(cfg_dict.items()):
+            if callable(v):
+                cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
+        return cfg_dict
+
+    def serialize_function(
+        self, value: Union[Callable, str], keep_callable=False
+    ) -> Union[Callable, str]:
+        """Serializes a given function or string.
+
+        If 'keep_callable' is True, the original callable is returned.
+        Otherwise, attempts to return the source code of the callable using 'getsource'.
+        """
+        if keep_callable:
+            return value
+        else:
+            try:
+                return getsource(value)
+            except (TypeError, OSError):
+                return str(value)
+
+
+class ConfigurableGroup(abc.ABC):
+    def __init__(
+        self,
+        config: Optional[dict] = None,
+    ) -> None:
+        self._config = GroupConfig(**config)
+
+    @property
+    def group(self):
+        return self._config.group
+
+    @property
+    def group_alias(self):
+        return self._config.group_alias
+
+    @property
+    def version(self):
+        return self._config.version
+
+    @property
+    def config(self):
+        return self._config.to_dict()
+
+    @property
+    def group_name(self) -> Any:
+        return self._config.group
+
+    def __repr__(self):
+        return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})"
diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/instance.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/instance.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3c6afa0644e729ba441728c72a2469fdad07b8f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/instance.py
@@ -0,0 +1,38 @@
+from dataclasses import dataclass, field
+from typing import Literal, Optional, Tuple
+
+
+OutputType = Literal[
+    "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
+]
+
+
+@dataclass
+class Instance:
+    request_type: OutputType
+    doc: dict
+    arguments: tuple
+    idx: int
+    metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
+        default_factory=lambda: (None, None, None)
+    )
+    resps: list = field(default_factory=list)
+    filtered_resps: dict = field(default_factory=dict)
+
+    # initialized after init
+    task_name: Optional[str] = None
+    doc_id: Optional[int] = None
+    repeats: Optional[int] = None
+
+    def __post_init__(self) -> None:
+        # unpack metadata field
+        self.task_name, self.doc_id, self.repeats = self.metadata
+
+    @property
+    def args(self):
+        """
+        Returns (string,) where `string` is the string to calculate loglikelihood over
+        """
+        return (
+            self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
+        )
diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/metrics.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..2aff6ce92a154a05df3d0bb7d28e09071cd12fbc
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/metrics.py
@@ -0,0 +1,578 @@
+import logging
+import math
+import random
+import re
+import string
+from collections.abc import Iterable
+from typing import List
+
+import numpy as np
+import sacrebleu
+
+from dllm_eval.api.registry import register_aggregation, register_metric
+
+
+eval_logger = logging.getLogger(__name__)
+
+
+# Register Aggregations First
+@register_aggregation("bypass")
+def bypass_agg(arr):
+    return 999
+
+
+@register_aggregation("nanmean")
+def nanmean(arr):
+    if len(arr) == 0 or all(np.isnan(arr)):
+        return np.nan
+    return np.nanmean(arr)
+
+
+@register_aggregation("mean")
+def mean(arr):
+    return sum(arr) / len(arr)
+
+
+@register_aggregation("median")
+def median(arr):
+    return arr[len(arr) // 2]
+
+
+# Certain metrics must be calculated across all documents in a benchmark.
+# We use them as aggregation metrics, paired with no-op passthrough metric fns.
+@register_aggregation("perplexity")
+def perplexity(items):
+    return math.exp(-mean(items))
+
+
+@register_aggregation("weighted_perplexity")
+def weighted_perplexity(items):
+    return math.exp(-weighted_mean(items))
+
+
+@register_aggregation("bits_per_byte")
+def bits_per_byte(items):
+    return -weighted_mean(items) / math.log(2)
+
+
+@register_aggregation("f1")
+def f1_score(items):
+    from sklearn.metrics import f1_score
+
+    unzipped_list = list(zip(*items))
+    golds = unzipped_list[0]
+    preds = unzipped_list[1]
+    fscore = f1_score(golds, preds)
+
+    return np.max(fscore)
+
+
+@register_aggregation("matthews_corrcoef")
+def matthews_corrcoef(items):
+    from sklearn.metrics import matthews_corrcoef
+
+    unzipped_list = list(zip(*items))
+    golds = unzipped_list[0]
+    preds = unzipped_list[1]
+    return matthews_corrcoef(golds, preds)
+
+
+@register_aggregation("bleu")
+def bleu(items):
+    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
+    for evaluating a generated sentence to a reference sentence. It counts matching
+    n-grams in the candidate translation to n-grams in the reference text, where
+    1-gram or unigram would be each token and a bigram comparison would be each
+    word pair. The comparison is made regardless of word order
+    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
+    Paper: https://www.aclweb.org/anthology/P02-1040/
+
+    Higher is better
+    """
+    refs = list(zip(*items))[0]
+    preds = list(zip(*items))[1]
+    refs, preds = _sacreformat(refs, preds)
+    return sacrebleu.corpus_bleu(preds, refs).score
+
+
+@register_aggregation("chrf")
+def chrf(items):
+    """chrF++ is a tool for automatic evaluation of machine translation output
+    based on character n-gram precision and recall enhanced with word n-grams.
+ Source: https://github.com/m-popovic/chrF + Paper: https://www.aclweb.org/anthology/W15-3049.pdf + + Higher is better # TODO I think + """ + refs = list(zip(*items))[0] + preds = list(zip(*items))[1] + refs, preds = _sacreformat(refs, preds) + return sacrebleu.corpus_chrf(preds, refs).score + + +@register_aggregation("ter") +def ter(items): + """Translation Error Rate is an error metric for machine translation that + measures the number of edits required to change a system output into one + of the references + Source: http://www.cs.umd.edu/~snover/tercom/ + Paper: http://mt-archive.info/AMTA-2006-Snover.pdf + + Lower is better + """ + refs = list(zip(*items))[0] + preds = list(zip(*items))[1] + refs, preds = _sacreformat(refs, preds) + return sacrebleu.corpus_ter(preds, refs).score + + +@register_aggregation("brier_score") +def brier_score(items): # This is a passthrough function + gold, predictions = list(zip(*items)) + bs, num_class = np.array(predictions).shape + + gold = list(gold) + gold_one_hot = np.eye(num_class)[gold] + return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1)) + + +@register_metric( + metric="brier_score", + higher_is_better=False, + output_type=["multiple_choice"], + aggregation="brier_score", +) +def brier_score_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="acc", + higher_is_better=True, + output_type=["loglikelihood", "multiple_choice"], + aggregation="mean", +) +def acc_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="acc_norm", + higher_is_better=True, + output_type=["loglikelihood", "multiple_choice"], + aggregation="mean", +) +def acc_norm_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="acc_mutual_info", + higher_is_better=True, + output_type="multiple_choice", + aggregation="mean", +) +def acc_mutual_info_fn(items): # This is a passthrough function + return items + + +### the code used in the `exact_match_hf_evaluate` function is ported from +### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py +### which is under the apache license. + +# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+def exact_match_hf_evaluate( + predictions, + references, + regexes_to_ignore=None, + ignore_case=False, + ignore_punctuation=False, + ignore_numbers=False, +): + if regexes_to_ignore is not None: + for s in regexes_to_ignore: + predictions = np.array([re.sub(s, "", x) for x in predictions]) + references = np.array([re.sub(s, "", x) for x in references]) + else: + predictions = np.asarray(predictions) + references = np.asarray(references) + + if ignore_case: + predictions = np.char.lower(predictions) + references = np.char.lower(references) + + if ignore_punctuation: + repl_table = string.punctuation.maketrans("", "", string.punctuation) + predictions = np.char.translate(predictions, table=repl_table) + references = np.char.translate(references, table=repl_table) + + if ignore_numbers: + repl_table = string.digits.maketrans("", "", string.digits) + predictions = np.char.translate(predictions, table=repl_table) + references = np.char.translate(references, table=repl_table) + + score_list = predictions == references + + return {"exact_match": np.mean(score_list)} + + +### + + +@register_metric( + metric="exact_match", + higher_is_better=True, + output_type="generate_until", + aggregation="mean", +) +def exact_match_fn(**kwargs): + return exact_match_hf_evaluate(**kwargs) + + +@register_metric( + metric="perplexity", + higher_is_better=False, + output_type="loglikelihood", + aggregation="perplexity", +) +def perplexity_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="word_perplexity", + higher_is_better=False, + output_type="loglikelihood_rolling", + aggregation="weighted_perplexity", +) +def word_perplexity_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="byte_perplexity", + higher_is_better=False, + output_type="loglikelihood_rolling", + aggregation="weighted_perplexity", +) +def byte_perplexity_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="bits_per_byte", + higher_is_better=False, + output_type="loglikelihood_rolling", + aggregation="bits_per_byte", +) +def bits_per_byte_fn(items): # This is a passthrough function + return items + + +def pop_stddev(arr): + mu = mean(arr) + return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr)) + + +def sample_stddev(arr): + mu = mean(arr) + return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1)) + + +def mean_stderr(arr): + return sample_stddev(arr) / math.sqrt(len(arr)) + + +@register_metric( + metric="bypass", + higher_is_better=True, + output_type=["loglikelihood", "multiple_choice", "generate_until"], + aggregation="bypass", +) +def bypass(items): + return None + + +@register_metric( + metric="mcc", + higher_is_better=True, + output_type="multiple_choice", + aggregation="matthews_corrcoef", +) +def mcc_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="f1", + higher_is_better=True, + output_type="multiple_choice", + aggregation="f1", +) +def f1_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="bleu", + higher_is_better=True, + output_type="generate_until", + aggregation="bleu", +) +def bleu_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="chrf", + higher_is_better=True, + output_type="generate_until", + aggregation="chrf", +) +def chrf_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="ter", + higher_is_better=True, + 
output_type="generate_until", + aggregation="ter", +) +def ter_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="acc_all", + higher_is_better=True, + output_type="loglikelihood", + aggregation="mean", +) +def acc_all(items): + # Only count as correct if all answers are labeled correctly for each question + question_scoring_dict = {} + preds = list(zip(*items))[0] + docs = list(zip(*items))[1] + + for doc, pred in zip(docs, preds): + paragraph_id = doc["idx"]["paragraph"] + question_id = doc["idx"]["question"] + if (paragraph_id, question_id) not in question_scoring_dict: + question_scoring_dict[(paragraph_id, question_id)] = [] + + gold_label = doc["label"] == 1 + + question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred) + acc = np.mean([int(all(x)) for x in question_scoring_dict.values()]) + return acc + + +def acc_all_stderr(items): + # Only count as correct if all answers are labeled correctly for each question + question_scoring_dict = {} + preds = list(zip(*items))[0] + docs = list(zip(*items))[1] + + for doc, pred in zip(docs, preds): + question_id = doc["idx"]["question"] + if question_id not in question_scoring_dict: + question_scoring_dict[question_id] = [] + + gold_label = doc["label"] == 1 + question_scoring_dict[question_id].append(gold_label == pred) + + acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()]) + return acc + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + """Compute max metric between prediction and each ground truth.""" + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def weighted_mean(items): + a, b = zip(*items) + return sum(a) / sum(b) + + +def is_non_str_iterable(obj): + return isinstance(obj, Iterable) and not isinstance(obj, str) + + +def _sacreformat(refs, preds): + """Format refs and preds for sacrebleu corpus calculation. It is very particular""" + # Sacrebleu expects (List[str], List[List[str]) + # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...]) + + # Note [ref1_stream] is the first reference for each pred. + # So lists are size N and (M, N) for N preds and M possible refs for each pred + # This is a different order of dimensions that I would expect + + # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds + # Must become List[List[str]] with the inner list corresponding to preds + if not is_non_str_iterable(refs): + refs = list(refs) + if not is_non_str_iterable(refs[0]): + refs = [[ref] for ref in refs] + refs = list(zip(*refs)) + # Note the number of refs in each ref list much match the number of preds + + # We expect preds to be List[str] or List[List[str]]. 
Must become List[str] + if not is_non_str_iterable(preds): + preds = list(preds) + if is_non_str_iterable(preds[0]): + assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}" + preds = [pred[0] for pred in preds] + + return refs, preds + + +# stderr stuff + + +class _bootstrap_internal: + def __init__(self, f, n) -> None: + self.f = f + self.n = n + + def __call__(self, v): + i, xs = v + rnd = random.Random() + rnd.seed(i) + res = [] + for _ in range(self.n): + res.append(self.f(rnd.choices(xs, k=len(xs)))) + return res + + +def bootstrap_stderr(f, xs, iters): + import multiprocessing as mp + + pool = mp.Pool(mp.cpu_count()) + # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something + # equivalent to stderr calculated without Bessel's correction in the stddev. + # Unfortunately, I haven't been able to figure out what the right correction is + # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but + # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator) + # Thankfully, shouldn't matter because our samples are pretty big usually anyways + res = [] + chunk_size = min(1000, iters) + from tqdm import tqdm + + print("bootstrapping for stddev:", f.__name__) + for bootstrap in tqdm( + pool.imap( + _bootstrap_internal(f, chunk_size), + [(i, xs) for i in range(iters // chunk_size)], + ), + total=iters // chunk_size, + ): + # sample w replacement + res.extend(bootstrap) + + pool.close() + return sample_stddev(res) + + +def stderr_for_metric(metric, bootstrap_iters: int): + if bootstrap_iters <= 0: + # return no function (don't compute stderr) if bootstrap iters = 0 + return None + + bootstrappable = [ + median, + matthews_corrcoef, + f1_score, + perplexity, + bleu, + chrf, + ter, + nanmean, + ] + + if metric in bootstrappable: + return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters) + + stderr = {mean: mean_stderr, acc_all: acc_all_stderr} + + return stderr.get(metric, None) + + +def pooled_sample_stderr(stderrs: List[float], sizes: List[int]): + # Used to aggregate bootstrapped stderrs across subtasks in a group, + # when we are weighting by the size of each subtask. + # + + assert len(stderrs) == len(sizes) + + # formula source: https://en.wikipedia.org/wiki/Pooled_variance + # and: https://stats.stackexchange.com/a/4841331 + # this empirically seems to match running `stderr_for_metric` on all instances + # from the subtasks concatenated with each other. + pooled_sample_var = ( + sum([(size - 1) * stderr**2 * size for size, stderr in zip(sizes, stderrs)]) + ) / (sum(sizes) - len(sizes)) + + return np.sqrt(pooled_sample_var / sum(sizes)) + + +def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None): + assert metrics is not None, ( + "Need to pass a list of each subtask's metric for this stderr aggregation" + ) + assert len(stderrs) == len(sizes) and len(sizes) == len(metrics) + + # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation. + # This formula depends on sample means. + # removed because it seems to give erroneously huge stderrs for groupings of tasks + # and does not seem to match up with bootstrap-calculated stderrs for groups. 
+ + ### don't use this unless a statistician has told you it's the right thing to do ### + + # accumulators: we'll aggregate pairwise N - 1 times + variance = stderrs[0] ** 2 + curr_size = sizes[0] + curr_score = metrics[0] + + for stderr, size, score in zip(stderrs[1:], sizes[1:], metrics[1:]): + curr_score = ((curr_score * curr_size) + (score * size)) / ( + curr_size + size + ) # NOTE: this assumes our aggregation fn is "mean" + + variance = ((curr_size - 1) * variance + (size - 1) * (stderr**2)) / ( + curr_size + size - 1 + ) + curr_size * size / ((curr_size + size) * (curr_size + size - 1)) * ( + curr_score - score + ) ** 2 + + return np.sqrt(variance) + + +def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True): + # A helper function that is used to aggregate + # subtask scores cross-task. + # TODO: does not hold for non-mean aggregations + if not weight_by_size: + sizes = [1] * len(sizes) + + assert len(metrics) == len(sizes) + + return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes) diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/model.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9364a9312d78c1029e5edf38d61f192afca91334 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/model.py @@ -0,0 +1,493 @@ +import abc +import hashlib +import json +import logging +import os +from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union + +import transformers +from sqlitedict import SqliteDict +from tqdm import tqdm + +from dllm_eval import utils + + +eval_logger = logging.getLogger(__name__) + +T = TypeVar("T", bound="LM") + + +class LM(abc.ABC): + def __init__(self) -> None: + """Defines the interface that should be implemented by all LM subclasses. + LMs are assumed to take text (strings) as input and yield strings as output + (inputs/outputs should be tokenization-agnostic.) + + """ + # set rank and world size to a single process, by default. + self._rank = 0 + self._world_size = 1 + self.cache_hook = CacheHook(None) + + @abc.abstractmethod + def loglikelihood(self, requests) -> List[Tuple[float, bool]]: + """Compute log-likelihood of generating a continuation from a context. + Downstream tasks should attempt to use loglikelihood instead of other + LM calls whenever possible. + + :param requests: list[Instance] + A list of Instance objects, with property `args` which returns a tuple (context, continuation). + `context: str` + Context string. Implementations of LM must be able to handle an + empty context string. + `continuation: str` + The continuation over which log likelihood will be calculated. If + there is a word boundary, the space should be in the continuation. + For example, context="hello" continuation=" world" is correct. + + :return: list[tuple[float, bool]] + A list of pairs (logprob, isgreedy) + `logprob: float` + The log probability of `continuation`. + `isgreedy`: + Whether `continuation` would be generated by greedy sampling from `context`. + """ + pass + + @abc.abstractmethod + def loglikelihood_rolling(self, requests) -> List[float]: + """Compute full log-likelihood of a string, with no truncation, for perplexity computation + - We will use the full max context length of the model. + - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to + the max context length. 
+ - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations + which may simply concatenate multiple documents together. + - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into + multiple chunks, the last input will still a full-sized context. + Example: + Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ] + Prefix: BOS/EOS + Max context length: 4 + Resulting input/prediction pairs: + + INPUT: BOS 0 1 2 + PRED: 0 1 2 3 + + INPUT: 3 4 5 6 + PRED: 4 5 6 7 + + INPUT: 5 6 7 8 + PRED: 8 9 + + Observe that: + 1. Each token is predicted exactly once + 2. For the last pair, we provide the full context, but only score the last two tokens + + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context,). + string: str + String for which we are computing overall loglikelihood + :return: list[tuple[float]] + A list of tuples (logprob,) + logprob: float + The log probability of `context` conditioned on the BOS/EOS token. + Can also be overridden for custom cases by `prefix_token_id`. + """ + pass + + # TODO: Add an optional max length + @abc.abstractmethod + def generate_until(self, requests) -> List[str]: + """Generate greedily until a stopping sequence + + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs). + context: str + Context string + gen_kwargs: dict + A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc. + :return: list[str] + A list of model generated continuations. + continuation: str + The generated continuation. + """ + pass + + def apply_chat_template( + self, chat_history: List[Dict[str, str]], add_generation_prompt=True + ) -> str: + """ + Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM. + + :param chat_history: list[dict[str, str]] + A list of dictionaries with keys 'role' and 'content'. + Values are strings representing the role name and the content of the message, respectively. + :param add_generation_prompt: bool + Whether to append an assistant gen prefix (for e.g. <|assistant|>) to the assistant messages in the chat history. False if prefilling an assistant message. + :return: str + A string representing the chat history in a format that can be used as input to the LM. + """ + raise NotImplementedError( + "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type." + ) + + @classmethod + def create_from_arg_string( + cls: Type[T], arg_string: str, additional_config: Optional[dict] = None + ) -> T: + """ + Creates an instance of the LM class using the given argument string and additional config. + + Parameters: + - arg_string: A string containing arguments in the format key1=value1,key2=value2. + - additional_config: Optional dictionary containing additional configuration parameters. + + Returns: + - Instance of the LM class. 
+ """ + additional_config = {} if additional_config is None else additional_config + args = utils.simple_parse_args_string(arg_string) + args2 = {k: v for k, v in additional_config.items() if v is not None} + return cls(**args, **args2) + + @classmethod + def create_from_arg_obj( + cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None + ) -> T: + """ + Creates an instance of the LM class using the given arg_obj + + Parameters: + - arg_obj: A dict containing arguments in the format key1=value1,key2=value2. + - additional_config: Optional dictionary containing additional configuration parameters. + + Returns: + - Instance of the LM class. + """ + + additional_config = {} if additional_config is None else additional_config + additional_config = { + k: v for k, v in additional_config.items() if v is not None + } + + return cls(**arg_dict, **additional_config) + + @property + def rank(self): + # used in the case of parallelism. Hardcoded to + # ensure no errors arise using API models which do + # not support multi-device parallelism nor expect it. + return self._rank + + @property + def world_size(self): + # used in the case of parallelism. Hardcoded to + # ensure no errors arise using API models which do + # not support multi-device parallelism nor expect it. + return self._world_size + + @property + def tokenizer_name(self) -> str: + """Must be defined for LM subclasses which implement Chat Templating. + Should return the name of the tokenizer or chat template used. + Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used. + """ + raise NotImplementedError( + "To use this model with chat templates, please implement the 'tokenizer_name' property." + ) + + def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: + """Returns the chat template structure for user/assistant messages if a template is provided. + This method is intended to be overridden in a subclass to define a specific chat template format. + For models that do not support chat templates, this method returns None by default. + """ + + return "" + + def set_cache_hook(self, cache_hook) -> None: + self.cache_hook = cache_hook + + +### SQLite-based caching of LM responses +def hash_args(attr, args): + dat = json.dumps([attr] + list(args)) + return hashlib.sha256(dat.encode("utf-8")).hexdigest() + + +class CacheHook: + def __init__(self, cachinglm) -> None: + if cachinglm is None: + self.dbdict = None + return + + self.dbdict = cachinglm.dbdict + + def add_partial(self, attr, req, res) -> None: + if self.dbdict is None: + return + hsh = hash_args(attr, req) + self.dbdict[hsh] = res + + +class CachingLM: + def __init__(self, lm, cache_db) -> None: + """LM wrapper that returns cached results if they exist, and uses the underlying LM if not. 
+ + :param lm: LM + Underlying LM + :param cache_db: str + Path to cache db + """ + self.lm = lm + self.cache_db = cache_db + if os.path.dirname(cache_db): + os.makedirs(os.path.dirname(cache_db), exist_ok=True) + self.dbdict = SqliteDict(cache_db, autocommit=True) + + # add hook to lm + lm.set_cache_hook(self.get_cache_hook()) + + def __getattr__(self, attr: str): + lm_attr = getattr(self.lm, attr) + if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]: + eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM") + return lm_attr + + def fn(requests): + res = [] + remaining_reqs = [] + warned = False + # figure out which ones are cached and which ones are new + eval_logger.info( + f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..." + ) + for req in tqdm(requests, desc="Checking cached requests"): + hsh = hash_args(attr, req.args) + if attr == "generate_until" and req.args[1].get("do_sample", False): + # when we are doing non-greedy generation, don't use the cache + # (else every "randomly sampled" generation would be identical for repeats > 1). + if not warned: + eval_logger.warning( + f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests." + ) + warned = True + res.append(None) + remaining_reqs.append(req) + elif hsh in self.dbdict: + ob = self.dbdict[hsh] + + assert ob is not None + + res.append(ob) + else: + res.append(None) + remaining_reqs.append(req) + eval_logger.info( + f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}" + ) + if remaining_reqs: + # actually run the LM on the requests that do not have cached results + rem_res = getattr(self.lm, attr)(remaining_reqs) + else: + rem_res = [] + + # stick the new ones back into the list and also cache any of the new ones + resptr = 0 + for req, r in zip(remaining_reqs, rem_res): + while res[resptr] is not None: + resptr += 1 + + res[resptr] = r + + # caching + hsh = hash_args(attr, req.args) + self.dbdict[hsh] = r + self.dbdict.commit() + + return res + + return fn + + def get_cache_hook(self): + return CacheHook(self) + + +class TemplateLM(LM): + """ + A class acting as intermediary between the LM base class + and boilerplate often included in other LM subclasses. + """ + + tokenizer = None + + @property + @abc.abstractmethod + def eot_token_id(self): + pass + + @property + def prefix_token_id(self): + # it is used as prefix for loglikelihood + return self.eot_token_id + + @abc.abstractmethod + def tok_encode(self, string: str, **kwargs) -> List[int]: + """ + Tokenize a string using the model's tokenizer and return a list of token IDs. 
+ """ + pass + + @abc.abstractmethod + def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]: + pass + + def _encode_pair( + self, context: str, continuation: str + ) -> Tuple[List[int], List[int]]: + n_spaces = len(context) - len(context.rstrip()) + if n_spaces > 0: + continuation = context[-n_spaces:] + continuation + context = context[:-n_spaces] + + model_class = getattr(self, "AUTO_MODEL_CLASS", None) + + if model_class == transformers.AutoModelForSeq2SeqLM: + context_enc = self.tok_encode(context) + continuation_enc = self.tok_encode(continuation, add_special_tokens=False) + else: + whole_enc = self.tok_encode(context + continuation) + context_enc = self.tok_encode(context) + + context_enc_len = len(context_enc) + continuation_enc = whole_enc[context_enc_len:] + + return context_enc, continuation_enc + + def loglikelihood( + self, requests, disable_tqdm: bool = False + ) -> List[Tuple[float, bool]]: + new_reqs = [] + for context, continuation in [req.args for req in requests]: + if context == "": + # BOS or EOS as context + context_enc, continuation_enc = ( + [self.prefix_token_id], + self.tok_encode(continuation), + ) + else: + context_enc, continuation_enc = self._encode_pair(context, continuation) + + new_reqs.append(((context, continuation), context_enc, continuation_enc)) + + return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm) + + @abc.abstractmethod + def loglikelihood_rolling( + self, requests, disable_tqdm: bool = False + ) -> List[float]: + pass + + @abc.abstractmethod + def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]: + pass + + def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: + """ + Set and get the appropriate chat template for the model. + This method sets the tokenizer's chat_template and returns the template string for reproducibility. + + The template selection logic is adapted from the Transformers library's `apply_chat_template` + method in the Tokenizer class. The original implementation can be found at: + https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687 + + This method ensures that the right template is chosen based on the following: + 0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string. + 1. If the model's tokenizer has multiple templates: + a. Use the specified template if it exists in the dictionary. + b. Use the default template from the list if no specific template is provided. + c. Raise an error if no default template exists and no specific template is provided. + 2. If the model's tokenizer has a single template or no template: + a. Use the tokenizer's chat template if available. + b. Fall back to the default chat template if no tokenizer chat template exists. + + Args: + chat_template (Union[bool, str]): Specifies the chat template to use. + - If False or None, no template is applied. + - If True, the default or only available template is used. + - If a string, the template with the matching name is used. + + Returns: + Optional[str]: The selected chat template, or None if no template is applied. + """ + if self.tokenizer is None: + return "" + + if chat_template is False or chat_template is None: + eval_logger.warning( + "model.chat_template was called with the chat_template set to False or None. 
" + "Therefore no chat template will be applied. Make sure this is an intended behavior." + ) + return None + + # Convert boolean chat_template to None to ensure compatibility with the adapted logic + if isinstance(chat_template, bool): + chat_template = None + using_default_template = False + + # First, handle the cases when the model has a dict of multiple templates + try: + template = ( + self.tokenizer.chat_template or self.tokenizer.default_chat_template + ) + except AttributeError: + return None + + if isinstance(template, dict): + using_default_dict = self.tokenizer.chat_template is None + + if chat_template is not None: + if chat_template in template: + selected_template = template[chat_template] + if using_default_dict: + using_default_template = True + else: + raise ValueError( + f"The specified chat template '{chat_template}' is not available. " + f"Available template names are {sorted(template.keys())}." + ) + else: + # If user didn't pass a chat template, use the default template from the dict + if "default" in template: + selected_template = template["default"] + using_default_template = True + else: + raise ValueError( + "This model has multiple chat templates with no default specified! Please either pass a chat " + "template or the name of the template you wish to use to the `chat_template` argument. Available " + f"template names are {sorted(template.keys())}." + ) + + # Cases when the model has a single template or no template + else: + # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template + if isinstance(chat_template, str): + eval_logger.warning( + "Chat template name provided, but the tokenizer's chat template is not a dictionary. " + "Using the tokenizer's chat template or the default template instead." + ) + if self.tokenizer.chat_template is not None: + selected_template = self.tokenizer.chat_template + else: + selected_template = self.tokenizer.default_chat_template + using_default_template = True + + if using_default_template: + eval_logger.warning( + "No chat template is set for this tokenizer, falling back to a default class-level template. This is " + "very error-prone, because models are often trained with templates different from the class default! " + "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which " + "point any code depending on them will stop working. We recommend setting a valid chat template before " + "then to ensure that this model continues working without issues." + ) + + return selected_template diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/registry.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..bf2b2e415a0a19862a41bde307bbad2e6ba326f5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/registry.py @@ -0,0 +1,196 @@ +import logging +from typing import Callable, Dict, Union + +import evaluate as hf_evaluate + +from dllm_eval.api.model import LM + + +eval_logger = logging.getLogger(__name__) + +MODEL_REGISTRY = {} + + +def register_model(*names): + # either pass a list or a single alias. + # function receives them as a tuple of strings + + def decorate(cls): + for name in names: + assert issubclass(cls, LM), ( + f"Model '{name}' ({cls.__name__}) must extend LM class" + ) + + assert name not in MODEL_REGISTRY, ( + f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead." 
+ ) + + MODEL_REGISTRY[name] = cls + return cls + + return decorate + + +def get_model(model_name): + try: + return MODEL_REGISTRY[model_name] + except KeyError: + raise ValueError( + f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}" + ) + + +TASK_REGISTRY = {} +GROUP_REGISTRY = {} +ALL_TASKS = set() +func2task_index = {} + + +def register_task(name): + def decorate(fn): + assert name not in TASK_REGISTRY, ( + f"task named '{name}' conflicts with existing registered task!" + ) + + TASK_REGISTRY[name] = fn + ALL_TASKS.add(name) + func2task_index[fn.__name__] = name + return fn + + return decorate + + +def register_group(name): + def decorate(fn): + func_name = func2task_index[fn.__name__] + if name in GROUP_REGISTRY: + GROUP_REGISTRY[name].append(func_name) + else: + GROUP_REGISTRY[name] = [func_name] + ALL_TASKS.add(name) + return fn + + return decorate + + +OUTPUT_TYPE_REGISTRY = {} +METRIC_REGISTRY = {} +METRIC_AGGREGATION_REGISTRY = {} +AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {} +HIGHER_IS_BETTER_REGISTRY = {} +FILTER_REGISTRY = {} + +DEFAULT_METRIC_REGISTRY = { + "loglikelihood": [ + "perplexity", + "acc", + ], + "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"], + "multiple_choice": ["acc", "acc_norm"], + "generate_until": ["exact_match"], +} + + +def register_metric(**args): + # TODO: do we want to enforce a certain interface to registered metrics? + def decorate(fn): + assert "metric" in args + name = args["metric"] + + for key, registry in [ + ("metric", METRIC_REGISTRY), + ("higher_is_better", HIGHER_IS_BETTER_REGISTRY), + ("aggregation", METRIC_AGGREGATION_REGISTRY), + ]: + if key in args: + value = args[key] + assert value not in registry, ( + f"{key} named '{value}' conflicts with existing registered {key}!" + ) + + if key == "metric": + registry[name] = fn + elif key == "aggregation": + registry[name] = AGGREGATION_REGISTRY[value] + else: + registry[name] = value + + return fn + + return decorate + + +def get_metric(name: str, hf_evaluate_metric=False) -> Callable: + if not hf_evaluate_metric: + if name in METRIC_REGISTRY: + return METRIC_REGISTRY[name] + else: + eval_logger.warning( + f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..." + ) + + try: + metric_object = hf_evaluate.load(name) + return metric_object.compute + except Exception: + eval_logger.error( + f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric", + ) + + +def register_aggregation(name: str): + def decorate(fn): + assert name not in AGGREGATION_REGISTRY, ( + f"aggregation named '{name}' conflicts with existing registered aggregation!" + ) + + AGGREGATION_REGISTRY[name] = fn + return fn + + return decorate + + +def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]: + try: + return AGGREGATION_REGISTRY[name] + except KeyError: + eval_logger.warning(f"{name} not a registered aggregation metric!") + + +def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]: + try: + return METRIC_AGGREGATION_REGISTRY[name] + except KeyError: + eval_logger.warning(f"{name} metric is not assigned a default aggregation!") + + +def is_higher_better(metric_name) -> bool: + try: + return HIGHER_IS_BETTER_REGISTRY[metric_name] + except KeyError: + eval_logger.warning( + f"higher_is_better not specified for metric '{metric_name}'!" 
+ ) + + +def register_filter(name): + def decorate(cls): + if name in FILTER_REGISTRY: + eval_logger.info( + f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}" + ) + FILTER_REGISTRY[name] = cls + return cls + + return decorate + + +def get_filter(filter_name: Union[str, Callable]) -> Callable: + try: + return FILTER_REGISTRY[filter_name] + except KeyError as e: + if callable(filter_name): + return filter_name + else: + eval_logger.warning(f"filter `{filter_name}` is not registered!") + raise e diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/samplers.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..969789ef2111dcb8ee3b7eed4c69d54572d6c302 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/samplers.py @@ -0,0 +1,232 @@ +import logging +import warnings +from functools import partial +from typing import TYPE_CHECKING, Iterable, Optional, Union + +import datasets + + +if TYPE_CHECKING: + from random import Random + + from dllm_eval.api.task import ConfigurableTask, Task + +eval_logger = logging.getLogger("lm-eval") + + +class ContextSampler: + def __init__( + self, + docs: list[dict], + task: Union["Task", "ConfigurableTask"], + fewshot_indices: Optional[Iterable] = None, + rnd: Optional["Random"] = None, + ) -> None: + self.rnd = rnd + if not self.rnd: + raise ValueError( + "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!" + ) + + self.task = task + self.config = task._config + + self.target_delimiter = self.config.target_delimiter + self.fewshot_delimiter = self.config.fewshot_delimiter + + if ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("doc_to_text", None) is not None + ): + self.doc_to_text = partial( + self.task.doc_to_text, + doc_to_text=self.config.fewshot_config.get("doc_to_text", None), + ) + else: + self.doc_to_text = self.task.doc_to_text + + if ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("doc_to_target", None) is not None + ): + self.doc_to_target = partial( + self.task.doc_to_target, + doc_to_target=self.config.fewshot_config.get("doc_to_target", None), + ) + else: + self.doc_to_target = self.task.doc_to_target + + if ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("doc_to_choice", None) is not None + ): + self.doc_to_choice = partial( + self.task.doc_to_choice, + doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None), + ) + else: + self.doc_to_choice = self.task.doc_to_choice + + self.docs = docs # HF dataset split, provided by task._fewshot_docs() + if fewshot_indices: # subset few-shot docs from + if not isinstance(self.docs, datasets.Dataset): + raise ValueError( + "Got `fewshot_indices` but fewshot_docs are not a HF dataset. 
Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously" + ) + self.docs = self.docs.select(fewshot_indices) + + def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None): + # draw an extra fewshot sample if using same split as evaluating on + prefix = gen_prefix + " " if gen_prefix else "" + n_samples = ( + num_fewshot + 1 + if self.config.fewshot_split == self.config.test_split + else num_fewshot + ) + + # draw `n_samples` docs from fewshot_docs + fewshotex = self.sample(n_samples) + + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + # TODO: should we just stop people from using fewshot from same split as evaluating? + selected_docs = [x for x in fewshotex if x != doc][:num_fewshot] + + labeled_examples = "" + for doc in selected_docs: + doc_content = self.doc_to_text(doc) + doc_target = self.doc_to_target(doc) + if self.config.doc_to_choice is None or isinstance(doc_content, str): + labeled_examples += doc_content + else: + labeled_examples += self.doc_to_choice(doc)[doc_content] + + if doc_target != "": + if self.target_delimiter.isspace() and str(doc_target)[0].isspace(): + # TODO: add logger warn once here. + warnings.warn( + "Both target_delimiter and target start with a space. This may cause issues.", + Warning, + stacklevel=2, + ) + labeled_examples += self.target_delimiter + labeled_examples += prefix + labeled_examples += ( + str(doc_target[0]) + if isinstance(doc_target, list) + else doc_target + if self.config.doc_to_choice is None or isinstance(doc_target, str) + else str(self.doc_to_choice(doc)[doc_target]) + ) + labeled_examples += self.fewshot_delimiter + + return labeled_examples + + def get_chat_context( + self, + doc: dict, + num_fewshot: int, + fewshot_as_multiturn: bool = False, + gen_prefix: Optional[str] = None, + ): + # TODO: Do we need any other delimiter + prefix = gen_prefix + " " if gen_prefix else "" + chat_history = [] + # draw an extra fewshot sample if using same split as evaluating on + n_samples = ( + num_fewshot + 1 + if self.config.fewshot_split == self.config.test_split + else num_fewshot + ) + # draw `n_samples` docs from fewshot_docs + fewshotex = self.sample(n_samples) + + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + # TODO: should we just stop people from using fewshot from same split as evaluating? + selected_docs = [x for x in fewshotex if x != doc][:num_fewshot] + + if fewshot_as_multiturn: + for doc in selected_docs: + doc_content = self.doc_to_text(doc) + doc_target = self.doc_to_target(doc) + chat_history.append( + { + "role": "user", + "content": doc_content + if self.config.doc_to_choice is None + or isinstance(doc_content, str) + else self.doc_to_choice(doc)[doc_content], + } + ) + chat_history.append( + { + "role": "assistant", + "content": prefix + str(doc_target[0]) + if isinstance(doc_target, list) + else prefix + doc_target + if self.config.doc_to_choice is None + or isinstance(doc_target, str) + else prefix + str(self.doc_to_choice(doc)[doc_target]), + } + ) + else: + # get fewshot context as one user turn + chat_history.append( + { + "role": "user", + "content": self.get_context( + doc, num_fewshot, gen_prefix=gen_prefix + ), + } + ) + + return chat_history + + def sample(self, n: int): + """ + Draw `n` samples from our fewshot docs. This method should be overridden by subclasses. 
+ """ + + return self.rnd.sample(self.docs, n) + + +class FirstNSampler(ContextSampler): + def sample(self, n: int) -> None: + """ + Draw the first `n` samples in order from the specified split. + Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU. + """ + assert n <= len(self.docs), ( + f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available." + ) + return self.docs[:n] + + +class BalancedSampler(ContextSampler): + def sample(self, n: int) -> None: + """ + TODO: this should return approximately class-balanced samples from our fewshot examples. + TODO: what order should they be in? maybe random? + """ + + pass + + +class ManualSampler(ContextSampler): + def sample(self, n: int) -> None: + """ """ + pass + + +SAMPLER_REGISTRY = { + "default": ContextSampler, + "first_n": FirstNSampler, +} + + +def get_sampler(name: str): + try: + return SAMPLER_REGISTRY[name] + except KeyError: + raise ValueError( + f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}" + ) diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/task.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/task.py new file mode 100644 index 0000000000000000000000000000000000000000..4a6321af0b2b8777e0322745a9875656ec194190 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/api/task.py @@ -0,0 +1,1881 @@ +import abc +import ast +import logging +import random +import re +from collections.abc import Callable +from copy import deepcopy +from dataclasses import asdict, dataclass +from inspect import getsource +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Literal, + Mapping, + Optional, + Tuple, + Union, +) + +import datasets +import numpy as np +from tqdm import tqdm + +from dllm_eval import utils +from dllm_eval.api import samplers +from dllm_eval.api.instance import Instance, OutputType +from dllm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity +from dllm_eval.api.registry import ( + AGGREGATION_REGISTRY, + DEFAULT_METRIC_REGISTRY, + get_aggregation, + get_metric, + get_metric_aggregation, + is_higher_better, +) +from dllm_eval.caching.cache import load_from_cache, save_to_cache +from dllm_eval.filters import build_filter_ensemble +from dllm_eval.prompts import get_prompt + + +ALL_OUTPUT_TYPES = [ + "loglikelihood", + "multiple_choice", + "loglikelihood_rolling", + "generate_until", +] + +eval_logger = logging.getLogger(__name__) + + +@dataclass +class TaskConfig(dict): + # task naming/registry + task: Optional[str] = None + task_alias: Optional[str] = None + tag: Optional[Union[str, list]] = None + # HF dataset options. + # which dataset to use, + # and what splits for what purpose + custom_dataset: Optional[Callable] = None + dataset_path: Optional[str] = None + dataset_name: Optional[str] = None + dataset_kwargs: Optional[dict] = None + training_split: Optional[str] = None + validation_split: Optional[str] = None + test_split: Optional[str] = None + fewshot_split: Optional[str] = ( + None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?) + ) + # formatting / prompting options. 
+ # see docs/advanced_task_guide.md for more info + process_docs: Optional[Callable] = None + doc_to_text: Optional[Union[Callable, str]] = None + doc_to_target: Optional[Union[Callable, str]] = None + doc_to_image: Union[Callable, str] = None + doc_to_audio: Union[Callable, str] = None + unsafe_code: bool = False + doc_to_choice: Optional[Union[Callable, str, dict, list]] = None + process_results: Optional[Union[Callable, str]] = None + use_prompt: Optional[str] = None + description: str = "" + target_delimiter: str = " " + fewshot_delimiter: str = "\n\n" + fewshot_config: Optional[dict] = None + # runtime configuration options + num_fewshot: Optional[int] = None + # scoring options + metric_list: Optional[list] = None + output_type: OutputType = "generate_until" + generation_kwargs: Optional[dict] = None + repeats: int = 1 + filter_list: Optional[Union[str, list]] = None + should_decontaminate: bool = False + doc_to_decontamination_query: Optional[str] = None + gen_prefix: Optional[str] = None + metadata: Optional[dict] = ( + None # by default, not used in the code. allows for users to pass arbitrary info to tasks + ) + + def __post_init__(self) -> None: + if self.generation_kwargs is not None: + if self.output_type != "generate_until": + eval_logger.warning( + f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!" + ) + + if "temperature" in self.generation_kwargs: + self.generation_kwargs["temperature"] = float( + self.generation_kwargs["temperature"] + ) + + if "until" not in self.generation_kwargs: + eval_logger.warning( + f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}" + ) + self.generation_kwargs["until"] = [self.fewshot_delimiter] + else: + if self.output_type == "generate_until": + # ensure that we greedily generate in absence of explicit arguments otherwise + self.generation_kwargs = { + "until": ( + None + if self.fewshot_delimiter is None + else [self.fewshot_delimiter] + ), + "do_sample": False, + "temperature": 0, + } + eval_logger.warning( + f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}" + ) + + def __getitem__(self, item): + return getattr(self, item) + + def __setitem__(self, item, value): + return setattr(self, item, value) + + def to_dict(self, keep_callable: bool = False) -> dict: + """dumps the current config as a dictionary object, as a printable format. + null fields will not be printed. + Used for dumping results alongside full task configuration + + :return: dict + A printable dictionary version of the TaskConfig object. + + # TODO: should any default value in the TaskConfig not be printed? + """ + cfg_dict = asdict(self) + # remove values that are `None` + for k, v in list(cfg_dict.items()): + if v is None: + cfg_dict.pop(k) + elif k == "metric_list": + for metric_dict in v: + for metric_key, metric_value in metric_dict.items(): + if callable(metric_value): + metric_dict[metric_key] = self.serialize_function( + metric_value, keep_callable=keep_callable + ) + cfg_dict[k] = v + elif callable(v): + cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable) + return cfg_dict + + def serialize_function( + self, value: Union[Callable, str], keep_callable=False + ) -> Union[Callable, str]: + """Serializes a given function or string. + + If 'keep_callable' is True, the original callable is returned. + Otherwise, attempts to return the source code of the callable using 'getsource'. 
+ """ + if keep_callable: + return value + else: + try: + return getsource(value) + except (TypeError, OSError): + return str(value) + + +class Task(abc.ABC): + """A task represents an entire benchmark including its dataset, problems, + answers, and evaluation methods. See BoolQ for a simple example implementation + + A `doc` can be any python object which represents one instance of evaluation. + This is usually a dictionary e.g. + {"question": ..., "answer": ...} or + {"question": ..., question, answer) + """ + + VERSION: Optional[Union[int, str]] = None + + # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub + # or a path to a custom `datasets` loading script. + DATASET_PATH: Optional[str] = None + + # The name of a subset within `DATASET_PATH`. + DATASET_NAME: Optional[str] = None + + OUTPUT_TYPE: Optional[OutputType] = None + + def __init__( + self, + data_dir: Optional[str] = None, + cache_dir: Optional[str] = None, + download_mode: Optional[datasets.DownloadMode] = None, + config: Optional[Mapping] = None, # Union[dict, TaskConfig] + ) -> None: + """ + :param data_dir: str + Stores the path to a local folder containing the `Task`'s data files. + Use this to specify the path to manually downloaded data (usually when + the dataset is not publicly accessible). + :param cache_dir: str + The directory to read/write the `Task` dataset. This follows the + HuggingFace `datasets` API with the default cache directory located at: + `~/.cache/huggingface/datasets` + NOTE: You can change the cache location globally for a given process + to another directory: + `export HF_DATASETS_CACHE="/path/to/another/directory"` + :param download_mode: datasets.DownloadMode + How to treat pre-existing `Task` downloads and data. + - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS` + Reuse download and reuse dataset. + - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS` + Reuse download with fresh dataset. + - `datasets.DownloadMode.FORCE_REDOWNLOAD` + Fresh download and fresh dataset. + """ + self.download(data_dir, cache_dir, download_mode) + self._training_docs: Optional[list] = None + self._fewshot_docs: Optional[list] = None + self._instances: Optional[List[Instance]] = None + + self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig() + + self._filters = [build_filter_ensemble("none", [["take_first", None]])] + self.fewshot_rnd: Optional[random.Random] = ( + None # purposely induce errors in case of improper usage + ) + + def download( + self, + data_dir: Optional[str] = None, + cache_dir: Optional[str] = None, + download_mode=None, + ) -> None: + """Downloads and returns the task dataset. + Override this method to download the dataset from a custom API. + + :param data_dir: str + Stores the path to a local folder containing the `Task`'s data files. + Use this to specify the path to manually downloaded data (usually when + the dataset is not publicly accessible). + :param cache_dir: str + The directory to read/write the `Task` dataset. This follows the + HuggingFace `datasets` API with the default cache directory located at: + `~/.cache/huggingface/datasets` + NOTE: You can change the cache location globally for a given process + by setting the shell environment variable, `HF_DATASETS_CACHE`, + to another directory: + `export HF_DATASETS_CACHE="/path/to/another/directory"` + :param download_mode: datasets.DownloadMode + How to treat pre-existing `Task` downloads and data. + - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS` + Reuse download and reuse dataset. 
+ - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS` + Reuse download with fresh dataset. + - `datasets.DownloadMode.FORCE_REDOWNLOAD` + Fresh download and fresh dataset. + """ + self.dataset = datasets.load_dataset( + path=self.DATASET_PATH, + name=self.DATASET_NAME, + data_dir=data_dir, + cache_dir=cache_dir, + download_mode=download_mode, + ) + + @property + def config(self) -> TaskConfig: + """Returns the TaskConfig associated with this class.""" + return self._config + + @abc.abstractmethod + def has_training_docs(self): + """Whether the task has a training set""" + pass + + @abc.abstractmethod + def has_validation_docs(self): + """Whether the task has a validation set""" + pass + + @abc.abstractmethod + def has_test_docs(self): + """Whether the task has a test set""" + pass + + def training_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + return [] + + def validation_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + return [] + + def test_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + return [] + + def fewshot_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + if self.has_training_docs(): + return self.training_docs() + elif self.has_validation_docs(): + return self.validation_docs() + else: + if self.config.get("num_fewshot", 0) > 0: + eval_logger.warning( + f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False" + ", using test_docs as fewshot_docs but this is not recommended." + ) + return self.test_docs() + + def _process_doc(self, doc: dict) -> dict: + """ + Override this to process (detokenize, strip, replace, etc.) individual + documents. This can be used in a map over documents of a data split. + E.g. `map(self._process_doc, self.dataset["validation"])` + + :return: dict + The processed version of the specified `doc`. + """ + return doc + + @property + def instances(self) -> List[Instance]: + """After calling `task.build_all_requests()`, tasks + maintain a list of the dataset instances which will be evaluated. + """ + return self._instances + + def fewshot_examples(self, k, rnd): + if self._training_docs is None: + self._training_docs = list(self.training_docs()) + + return rnd.sample(self._training_docs, k) + + def doc_to_decontamination_query(self, doc): + raise NotImplementedError( + "Override doc_to_decontamination_query with document specific decontamination query." 
+ ) + + @abc.abstractmethod + def doc_to_text(self, doc): + pass + + @abc.abstractmethod + def doc_to_target(self, doc): + pass + + # not an abstractmethod because not every language-only task has to implement this + def doc_to_image(self, doc): + raise NotImplementedError + + def doc_to_audio(self, doc): + raise NotImplementedError + + def doc_to_prefix(self, doc): + return "" + + def build_all_requests( + self, + *, + limit: Union[int, None] = None, + samples: Optional[List[int]] = None, + rank: int = 0, + world_size: int = 1, + cache_requests: bool = False, + rewrite_requests_cache: bool = False, + system_instruction: Optional[str] = None, + apply_chat_template: bool = False, + fewshot_as_multiturn: bool = False, + chat_template: Optional[Callable] = None, + tokenizer_name: str = "", + ) -> None: + """Build a set of Instances for a task, and store them in task.instances""" + + # used with caching + og_limit = limit + + cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}" + cache_key += "-chat_template" if apply_chat_template else "" + cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else "" + cache_key += ( + f"-system_prompt_hash{utils.hash_string(system_instruction)}" + if system_instruction is not None + else "" + ) + cache_key += f"-tokenizer{tokenizer_name}" + + cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests) + + if cache_requests and cached_instances and not rewrite_requests_cache: + cached_instances = cached_instances[:limit] + + flattened_instances = [ + instance + for instance_group in cached_instances + for instance in instance_group + ] + + self._instances = flattened_instances + return + + eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...") + + instances = [] + + # process all documents when caching is specified for simplicity + if ( + cache_requests + and (not cached_instances or rewrite_requests_cache) + and limit is not None + ): + limit = None + + doc_id_docs = list( + self.doc_iterator( + rank=rank, limit=limit, samples=samples, world_size=world_size + ) + ) + + num_docs = len(doc_id_docs) + + for doc_id, doc in tqdm( + doc_id_docs, + total=num_docs, + ): + # sample fewshot context #TODO: need to offset doc_id by rank now! 
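# Worked example (illustrative, not from the diff) of the request-cache key
# assembled above, for a hypothetical task "hellaswag" run 5-shot on rank 0 of
# a 2-process job with a chat template and tokenizer name "llama-3-8b":
#
#   "requests-hellaswag-5shot-rank0-world_size2-chat_template-tokenizerllama-3-8b"
#
# The system_prompt_hash segment is only appended when a system instruction is
# supplied; load_from_cache/save_to_cache add the shared pickle suffix on disk.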
+ fewshot_ctx = self.fewshot_context( + doc, + num_fewshot=0 + if self.config.num_fewshot is None + else self.config.num_fewshot, + system_instruction=system_instruction, + apply_chat_template=apply_chat_template, + fewshot_as_multiturn=fewshot_as_multiturn, + chat_template=chat_template, + gen_prefix=self.doc_to_prefix(doc), + ) + + # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute + inst = self.construct_requests( + doc=doc, + ctx=fewshot_ctx, + metadata=(self.config["task"], doc_id, self.config.repeats), + apply_chat_template=apply_chat_template, + chat_template=chat_template, + ) + + if not isinstance(inst, list): + inst = [inst] + + instances.append(inst) + + # now flatten, this is to allow slicing to work with pickles + + sliced_instances = instances[:og_limit] + + flattened_instances = [ + instance + for instance_group in sliced_instances + for instance in instance_group + ] + + self._instances = flattened_instances + + if len(self._instances) == 0: + raise ValueError("task.build_requests() did not find any docs!") + + if cache_requests and (not cached_instances or rewrite_requests_cache): + save_to_cache(file_name=cache_key, obj=instances) + + @abc.abstractmethod + def construct_requests(self, doc, ctx, **kwargs): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + :param doc_idx: int + The index of a document within `self.test_docs()` or `self.validation_docs()`, + whichever is the main split used. + :param repeats: int + TODO: update this docstring + The number of times each instance in a dataset is inferred on. Defaults to 1, + can be increased for techniques like majority voting. + """ + pass + + @abc.abstractmethod + def process_results(self, doc, results): + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of + the metric for that one document + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param results: + The results of the requests created in construct_requests. 
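# Hedged sketch (illustrative, not from the diff) of the Instance that
# build_all_requests() ends up storing for one document of a generate_until
# task: `arguments` carries the rendered context plus generation kwargs, and
# `metadata` is the (task_name, doc_id, repeats) triple threaded through
# construct_requests above. The doc contents and task name are hypothetical.
from dllm_eval.api.instance import Instance

inst = Instance(
    request_type="generate_until",
    doc={"question": "2 + 2 = ?", "answer": "4"},
    arguments=("Q: 2 + 2 = ?\nA:", {"until": ["\n\n"], "do_sample": False, "temperature": 0}),
    idx=0,
    metadata=("my_task", 0, 1),
)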
+ """ + pass + + @abc.abstractmethod + def aggregation(self): + """ + :returns: {str: [metric_score] -> float} + A dictionary where keys are the names of submetrics and values are + functions that aggregate a list of metric scores + """ + pass + + @abc.abstractmethod + def higher_is_better(self): + """ + :returns: {str: bool} + A dictionary where keys are the names of submetrics and values are + whether a higher value of the submetric is better + """ + pass + + def get_config(self, key: str) -> Any: + return getattr(self._config, key, None) + + @classmethod + def count_bytes(cls, doc): + """Used for byte-level perplexity metrics in rolling loglikelihood""" + return len(doc.encode("utf-8")) + + @classmethod + def count_words(cls, doc): + """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!""" + return len(re.split(r"\s+", doc)) + + @utils.positional_deprecated + def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs): + """Returns a fewshot context string that is made up of a prepended description + (if provided), the `num_fewshot` number of examples, and an appended prompt example. + + :param doc: str + The document as returned from training_docs, validation_docs, or test_docs. + :param num_fewshot: int + The number of fewshot examples to provide in the returned context string. + :param rnd: random.Random + The pseudo-random number generator used to randomly sample examples. + WARNING: This is currently a required arg although it's optionalized with a default `None`. + :param description: str + The task's description that will be prepended to the fewshot examples. + :returns: str + The fewshot context. + """ + if rnd is None: + if self.fewshot_rnd is not None: + rnd = self.fewshot_rnd + else: + raise ValueError( + "A `random.Random` generator argument must be provided to `rnd`" + ) + + description = description if description else "" + + if num_fewshot == 0: + labeled_examples = "" + else: + # for sets with no training docs, draw from other set *but ensure no overlap with current doc* + if self.has_training_docs(): + fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd) + else: + if self._fewshot_docs is None: + self._fewshot_docs = list( + self.validation_docs() + if self.has_validation_docs() + else self.test_docs() + ) + + fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1) + + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + fewshotex = [x for x in fewshotex if x != doc][:num_fewshot] + + labeled_examples = ( + "\n\n".join( + [ + self.doc_to_text(doc) + self.doc_to_target(doc) + for doc in fewshotex + ] + ) + + "\n\n" + ) + + example = self.doc_to_text(doc) + return description + labeled_examples + example + + def apply_filters(self) -> Optional[List[Instance]]: + """Iterates over FilterEnsembles and applies them to instances""" + if hasattr(self, "_filters"): + for f in self._filters: + f.apply(self._instances) + else: + eval_logger.warning("No filter defined, passing through instances") + return self._instances + + def dump_config(self) -> dict: + """Returns the config as a dictionary.""" + # TODO: this should only return the overrides applied to a non-YAML task's configuration. 
+ # (num_fewshot) + return self.config.to_dict() + + def set_config(self, key: str, value: Any, update: bool = False) -> None: + """Set or update the configuration for a given key.""" + if key is None: + raise ValueError("Key must be provided.") + + if update: + current_value = getattr(self._config, key, {}) + if not isinstance(current_value, dict): + raise TypeError( + f"Expected a dict for key '{key}', got {type(current_value).__name__} instead." + ) + current_value.update(value) + else: + setattr(self._config, key, value) + + def override_metric(self, metric_name: str) -> None: + """ + Override the default metrics used for evaluation with custom metrics. + + Parameters: + - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics. + """ + ( + self._metric_fn_list, + self._aggregation_list, + self._metric_fn_kwargs, + self._higher_is_better, + ) = ({}, {}, {}, {}) + self._metric_fn_list[metric_name] = get_metric(metric_name) + self._aggregation_list[metric_name] = get_metric_aggregation(metric_name) + self._higher_is_better[metric_name] = is_higher_better(metric_name) + self._metric_fn_kwargs[metric_name] = {} + if not isinstance(self, ConfigurableTask): + self.process_results = lambda x, y: {metric_name: get_metric(metric_name)} + self.aggregation = lambda: { + metric_name: get_metric_aggregation(metric_name) + } + setattr(self._config, "metric_list", [{"metric": metric_name}]) + setattr(self._config, "process_results", None) + + def set_fewshot_seed(self, seed: Optional[int] = None) -> None: + self.fewshot_rnd = random.Random(seed) + if hasattr(self, "sampler"): + self.sampler.rnd = self.fewshot_rnd + + @property + def eval_docs(self) -> Union[datasets.Dataset, List[dict]]: + if self.has_test_docs(): + return self.test_docs() + elif self.has_validation_docs(): + return self.validation_docs() + else: + raise ValueError( + f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!" + ) + + def doc_iterator( + self, + *, + rank: int = 0, + limit: Union[int, None] = None, + world_size: int = 1, + samples: Optional[List[int]] = None, + ) -> Iterator[Tuple[int, Any]]: + if samples: + n = len(self.eval_docs) + assert all([e < n for e in samples]), ( + f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}." 
+ ) + eval_logger.info( + f"{self.config.task}: Evaluating on {len(samples)} examples" + ) + doc_iterator = utils.create_iterator( + enumerate(x for i, x in enumerate(self.eval_docs) if i in samples), + rank=int(rank), + limit=None, # limit does not matter here since we are selecting samples directly + world_size=int(world_size), + ) + else: + limit = int(limit) if limit else None + doc_iterator = utils.create_iterator( + enumerate(self.eval_docs), + rank=int(rank), + limit=limit, + world_size=int(world_size), + ) + return doc_iterator + + +class ConfigurableTask(Task): + VERSION = "Yaml" + OUTPUT_TYPE = None + CONFIG = None + + def __init__( + self, + data_dir=None, + cache_dir=None, + download_mode=None, + config: Optional[dict] = None, + ) -> None: # TODO no super() call here + # Get pre-configured attributes + self._config = self.CONFIG + + # Use new configurations if there was no preconfiguration + if self.config is None: + self._config = TaskConfig(**config) + # Overwrite configs + else: + if config is not None: + self._config.__dict__.update(config) + + if self.config is None: + raise ValueError( + "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg" + ) + + if isinstance(self.config.metadata, dict): + if "version" in self.config.metadata: + self.VERSION = self.config.metadata["version"] + + if self.config.output_type is not None: + if self.config.output_type not in ALL_OUTPUT_TYPES: + raise ValueError( + f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'" + ) + self.OUTPUT_TYPE = self.config.output_type + + if self.config.doc_to_image is not None: + # mark the task as requiring multimodality. + self.MULTIMODAL = True + + if self.config.doc_to_audio: + # mark the task as requiring multimodality. + self.MULTIMODAL = True + + if self.config.unsafe_code is not False: + self.UNSAFE_CODE = True + + if self.config.dataset_path is not None: + self.DATASET_PATH = self.config.dataset_path + + if self.config.dataset_name is not None: + self.DATASET_NAME = self.config.dataset_name + + self._metric_fn_list = {} + self._metric_fn_kwargs = {} + self._aggregation_list = {} + self._higher_is_better = {} + + if self.config.metric_list is None: + # TODO: handle this in TaskConfig.__post_init__ ? + _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type] + + for metric_name in _metric_list: + self._metric_fn_list[metric_name] = get_metric(metric_name) + self._metric_fn_kwargs[metric_name] = {} + self._aggregation_list[metric_name] = get_metric_aggregation( + metric_name + ) + self._higher_is_better[metric_name] = is_higher_better(metric_name) + else: + for metric_config in self.config.metric_list: + if "metric" not in metric_config: + raise ValueError( + "'metric' key not provided for an entry in 'metric_list', must be specified!" 
+ ) + metric_name = metric_config["metric"] + kwargs = { + key: metric_config[key] + for key in metric_config + if key + not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"] + } + hf_evaluate_metric = ( + "hf_evaluate" in metric_config + and metric_config["hf_evaluate"] is True + ) + + if self.config.process_results is not None: + self._metric_fn_list[metric_name] = None + self._metric_fn_kwargs[metric_name] = {} + elif callable(metric_name): + metric_fn = metric_name.__call__ + metric_name = metric_name.__name__ + self._metric_fn_list[metric_name] = metric_fn + self._metric_fn_kwargs[metric_name] = kwargs + else: + self._metric_fn_list[metric_name] = get_metric( + metric_name, hf_evaluate_metric + ) + self._metric_fn_kwargs[metric_name] = kwargs + + if "aggregation" in metric_config: + agg_name = metric_config["aggregation"] + if isinstance(agg_name, str): + self._aggregation_list[metric_name] = get_aggregation(agg_name) + elif callable(agg_name): # noqa: E721 + self._aggregation_list[metric_name] = metric_config[ + "aggregation" + ] + else: + INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()} + metric_agg = get_metric_aggregation(metric_name) + eval_logger.warning( + f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. " + f"using default " + f"aggregation={INV_AGG_REGISTRY[metric_agg]}" + ) + self._aggregation_list[metric_name] = metric_agg + + if "higher_is_better" in metric_config: + self._higher_is_better[metric_name] = metric_config[ + "higher_is_better" + ] + else: + eval_logger.warning( + f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. " + f"using default " + f"higher_is_better={is_higher_better(metric_name)}" + ) + self._higher_is_better[metric_name] = is_higher_better(metric_name) + + self.download(self.config.dataset_kwargs) + self._training_docs = None + self._fewshot_docs = None + + if self.config.filter_list is not None: + self._filters = [] + for filter_config in self.config.filter_list: + filter_name = filter_config["name"] + filter_functions = filter_config["filter"] + components = [] + for function in filter_functions: + kwargs = { + key: function[key] for key in function if key != "function" + } + components.append([function["function"], kwargs]) + filter_pipeline = build_filter_ensemble(filter_name, components) + self._filters.append(filter_pipeline) + else: + # TODO: handle repeats in a more general way rather than just discarding + eval_logger.debug( + "No custom filters defined. Using default 'take_first' filter for handling repeats." 
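# Hedged sketch (not from the diff) of one metric_list entry as consumed by the
# loop above; "exact_match" and "ignore_case" are illustrative. Any key other
# than metric/aggregation/higher_is_better/hf_evaluate is forwarded to the
# metric function as a keyword argument.
metric_list = [
    {
        "metric": "exact_match",
        "aggregation": "mean",
        "higher_is_better": True,
        "hf_evaluate": True,     # resolve the metric via HF `evaluate`
        "ignore_case": True,     # passed through as a metric kwarg
    },
]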
+ ) + self._filters = [build_filter_ensemble("none", [["take_first", None]])] + + if self.config.use_prompt is not None: + eval_logger.info(f"loading prompt {self.config.use_prompt}") + self.prompt = get_prompt( + self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME + ) + else: + self.prompt = None + + if self.fewshot_docs() is not None: + self.fewshot_rnd = ( + random.Random() + ) # setting with no seed, to be overridden at a later time + config_sampler: Union[str, Callable] = ( + self.config.fewshot_config.get("sampler", "default") + if self.config.fewshot_config + else "default" + ) + if isinstance(config_sampler, str): + self.sampler = samplers.get_sampler(config_sampler)( + list(self.fewshot_docs()), self, rnd=self.fewshot_rnd + ) + elif callable(config_sampler) and issubclass( + config_sampler, samplers.ContextSampler + ): + self.sampler = config_sampler( + docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd + ) + else: + raise TypeError( + f"fewshot_config.sampler should be a string or callable of ContextSampler type, " + f"not {type(config_sampler)}" + ) + + self.task_docs = self.eval_docs + + # Test One Doc + self.features = list(self.task_docs.features.keys()) + self.multiple_input = 0 + self.multiple_target = 0 + test_doc = self.task_docs[0] + test_text = self.doc_to_text(test_doc) + test_target = self.doc_to_target(test_doc) + + if self.config.doc_to_choice is not None: + test_choice = self.doc_to_choice(test_doc) + if not isinstance(test_choice, list): + eval_logger.error("doc_to_choice must return list") + else: + num_choice = len(test_choice) + + if isinstance(test_text, int): + eval_logger.debug( + "doc_to_text returned an int. Assuming multiple inputs." + ) + self.multiple_input = num_choice + else: + test_choice = None + + if isinstance(test_target, list): + eval_logger.debug( + "doc_to_target returned a list. Assuming multiple targets." + ) + self.multiple_target = len(test_target) + else: + if (isinstance(test_target, int)) and (test_choice is not None): + test_target = test_choice[test_target] + else: + test_target = str(test_target) + + if test_choice is not None: + check_choices = test_choice + else: + check_choices = [test_target] + if self.config.doc_to_choice is not None: + for choice in check_choices: + choice_has_whitespace = True if choice[0].isspace() else False + delimiter_has_whitespace = ( + True + if self.config.target_delimiter.rstrip() + != self.config.target_delimiter + else False + ) + + if delimiter_has_whitespace and choice_has_whitespace: + eval_logger.debug( + f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace' + ) + elif (not delimiter_has_whitespace) and (not choice_has_whitespace): + eval_logger.debug( + f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace' + ) + + def download( + self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs + ) -> None: + if isinstance(self.config.custom_dataset, Callable): + eval_logger.warning( + f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager." + + "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme." 
+ ) + self.dataset = self.config.custom_dataset( + **(self.config.metadata or {}), **(self.config.dataset_kwargs or {}) + ) + else: + self.dataset = datasets.load_dataset( + path=self.DATASET_PATH, + name=self.DATASET_NAME, + **dataset_kwargs if dataset_kwargs is not None else {}, + ) + + def has_training_docs(self) -> bool: + if self.config.training_split is not None: + return True + else: + return False + + def has_validation_docs(self) -> bool: + if self.config.validation_split is not None: + return True + else: + return False + + def has_test_docs(self) -> bool: + if self.config.test_split is not None: + return True + else: + return False + + def training_docs(self) -> datasets.Dataset: + if self.has_training_docs(): + if self.config.process_docs is not None: + return self.config.process_docs( + self.dataset[self.config.training_split] + ) + return self.dataset[self.config.training_split] + + def validation_docs(self) -> datasets.Dataset: + if self.has_validation_docs(): + if self.config.process_docs is not None: + return self.config.process_docs( + self.dataset[self.config.validation_split] + ) + return self.dataset[self.config.validation_split] + + def test_docs(self) -> datasets.Dataset: + if self.has_test_docs(): + if self.config.process_docs is not None: + return self.config.process_docs(self.dataset[self.config.test_split]) + return self.dataset[self.config.test_split] + + def fewshot_docs(self): + if self.config.fewshot_split is not None: + if self.config.process_docs is not None: + return self.config.process_docs(self.dataset[self.config.fewshot_split]) + return self.dataset[self.config.fewshot_split] + elif ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("samples", None) is not None + ): + if isinstance(self.config.fewshot_config["samples"], list): + return self.config.fewshot_config["samples"] + elif callable(self.config.fewshot_config["samples"]): + return self.config.fewshot_config["samples"]() + else: + raise Exception( + "`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of samples as a dict, or function returning this list." + ) + else: + if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0): + eval_logger.warning( + f"[Task: {self.config.task}] " + "num_fewshot > 0 but fewshot_split is None. " + "using preconfigured rule." + ) + return super().fewshot_docs() + + @staticmethod + def append_target_question( + labeled_examples: List[Dict[str, str]], + question: str, + fewshot_as_multiturn: bool = False, + gen_prefix: Optional[str] = None, + ) -> None: + """Adds a target question to the labeled examples list. + If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry. + Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant. 
+ """ + if not fewshot_as_multiturn: + # if no messages or last message is system, append as new user entry + if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system": + labeled_examples.append({"role": "user", "content": question}) + # if last message is user, append to it to avoid two user messages in a row + else: + labeled_examples[-1]["content"] += question + else: + # if fewshot_as_multiturn is True, append as next user entry (last is always assistant) + labeled_examples.append({"role": "user", "content": question}) + if gen_prefix: + labeled_examples.append({"role": "assistant", "content": gen_prefix}) + + @utils.positional_deprecated + def fewshot_context( + self, + doc: dict, + num_fewshot: int, + system_instruction: Optional[str] = None, + apply_chat_template: bool = False, + fewshot_as_multiturn: bool = False, + chat_template: Optional[Callable] = None, + gen_prefix: Optional[str] = None, + ) -> Union[str, List[str]]: + """Returns a fewshot context string that is made up of a prepended description + (if provided), the `num_fewshot` number of examples, and an appended prompt example. + + :param doc: str + The document as returned from training_docs, validation_docs, or test_docs. + :param num_fewshot: int + The number of fewshot examples to provide in the returned context string. + :param system_instruction: str + System instruction to be applied to the prompt. + :param apply_chat_template: bool + Whether to apply the chat template to the fewshot context. + :param fewshot_as_multiturn: bool + Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param chat_template: + callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string. + :param gen_prefix: + String to append after the <|assistant|> token. + :returns: str + The fewshot context. + """ + if apply_chat_template: + labeled_examples = [] + else: + labeled_examples = "" + + # get task description + if description := self.config.description: + description = utils.apply_template(self.config.description, doc) + + # create system prompt based on the provided system instruction and description + if system_instruction is not None and description: + system_prompt = ( + f"{system_instruction}{self.sampler.fewshot_delimiter}{description}" + ) + elif system_instruction is not None: + system_prompt = system_instruction + elif description: + system_prompt = description + else: + system_prompt = "" + + # add system prompt if specified + if system_prompt: + if apply_chat_template: + labeled_examples.append({"role": "system", "content": system_prompt}) + else: + labeled_examples = system_prompt + # if few-shot - append examples after the system prompt + if num_fewshot > 0: + if apply_chat_template: + labeled_examples.extend( + self.sampler.get_chat_context( + doc, + num_fewshot, + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + ) + else: + labeled_examples += self.sampler.get_context( + doc, num_fewshot, gen_prefix=gen_prefix + ) + + example = self.doc_to_text(doc) + if apply_chat_template: + if self.multiple_input: + # TODO: append prefill? 
+ if not labeled_examples: + return "" + return chat_template(labeled_examples) + if isinstance(example, str): + self.append_target_question( + labeled_examples, + example, + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + # for loglikelihood create a list of questions with appended choices + elif isinstance(example, list): + labeled_examples_list = [] + # copy chat history for each example and append the answer + for ex in example: + chat = deepcopy(labeled_examples) + self.append_target_question( + chat, + ex, + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + # TODO: append prefill? + labeled_examples_list.append( + chat_template( + chat, + add_generation_prompt=False if gen_prefix else True, + ) + ) + return labeled_examples_list + # if example is an integer, append the choice or convert to string + elif isinstance(example, int): + if self.config.doc_to_choice is not None: + choices = self.doc_to_choice(doc) + self.append_target_question( + labeled_examples, + choices[example], + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + else: + self.append_target_question( + labeled_examples, + str(example), + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + # return lm.apply_chat_template(labeled_examples) + return chat_template( + labeled_examples, + add_generation_prompt=False if gen_prefix else True, + ) + else: + prefix = ( + self.config.target_delimiter + gen_prefix + if gen_prefix is not None + else "" + ) + if self.multiple_input: + return labeled_examples + if isinstance(example, str): + return labeled_examples + example + prefix + elif isinstance(example, list): + return [labeled_examples + ex + prefix for ex in example] + elif isinstance(example, int): + if self.config.doc_to_choice is not None: + choices = self.doc_to_choice(doc) + return labeled_examples + choices[example] + prefix + else: + return labeled_examples + str(example) + prefix + + def apply_filters(self) -> Optional[List[Instance]]: + """Iterates over FilterEnsembles and applies them to instances""" + if hasattr(self, "_filters"): + for f in self._filters: + f.apply(self._instances) + else: + eval_logger.warning("No filter defined, passing through instances") + return self._instances + + def should_decontaminate(self): + return self.config.should_decontaminate + + def doc_to_decontamination_query(self, doc: dict): + if self.config.should_decontaminate: + if self.config.doc_to_decontamination_query is None: + return self.doc_to_text(doc) + else: + doc_to_decontamination_query = self.config.doc_to_decontamination_query + if doc_to_decontamination_query in self.features: + return doc[doc_to_decontamination_query] + elif callable(doc_to_decontamination_query): + return doc_to_decontamination_query(doc) + else: + return ast.literal_eval( + utils.apply_template( + self.config.doc_to_decontamination_query, doc + ) + ) + + def _process_doc(self, doc: dict) -> dict: + """ + Override this to process (detokenize, strip, replace, etc.) individual + documents. This can be used in a map over documents of a data split. + E.g. `map(self._process_doc, self.dataset["validation"])` + + :return: dict + The processed version of the specified `doc`. 
+ """ + return doc + + def doc_to_text(self, doc, doc_to_text=None): + if self.prompt is not None: + doc_to_text = self.prompt + elif doc_to_text is not None: + doc_to_text = doc_to_text + else: + doc_to_text = self.config.doc_to_text + + if isinstance(doc_to_text, int): + return doc_to_text + elif isinstance(doc_to_text, str): + if doc_to_text in self.features: + # if self.config.doc_to_choice is not None: + # return self.doc_to_choice(doc)[doc[doc_to_text]] + # else: + return doc[doc_to_text] + else: + text_string = utils.apply_template(doc_to_text, doc) + if text_string.isdigit() and self._config.doc_to_choice is not None: + return ast.literal_eval(text_string) + else: + return text_string + elif callable(doc_to_text): + return doc_to_text(doc) + # Used when applying a Promptsource template + elif hasattr(doc_to_text, "apply"): + applied_prompt = doc_to_text.apply(doc) + if len(applied_prompt) == 2: + return applied_prompt[0] + else: + eval_logger.warning("Applied prompt returns empty string") + return self.config.fewshot_delimiter + else: + print(type(doc_to_text)) + raise TypeError + + def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]: + if self.prompt is not None: + doc_to_target = self.prompt + elif doc_to_target is not None: + doc_to_target = doc_to_target + else: + doc_to_target = self.config.doc_to_target + + if isinstance(doc_to_target, int): + return doc_to_target + elif isinstance(doc_to_target, str): + if doc_to_target in self.features: + # if self.config.doc_to_choice is not None: + # return self.doc_to_choice(doc)[doc[doc_to_target]] + # else: + return doc[doc_to_target] + else: + target_string = utils.apply_template(doc_to_target, doc) + if target_string.isdigit() and self._config.doc_to_choice is not None: + return ast.literal_eval(target_string) + elif ( + len(target_string) >= 2 + and (target_string[0] == "[") + and (target_string[-1] == "]") + ): + try: + return ast.literal_eval(target_string) + except (SyntaxError, ValueError): + return target_string + else: + return target_string + elif isinstance(doc_to_target, list): + return doc_to_target + elif callable(doc_to_target): + return doc_to_target(doc) + # Used when applying a Promptsource template + elif hasattr(doc_to_target, "apply"): + applied_prompt = doc_to_target.apply(doc) + if len(applied_prompt) == 2: + return applied_prompt[1] + else: + eval_logger.warning("Applied prompt returns empty string") + return self.config.fewshot_delimiter + else: + raise TypeError + + def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]: + if self.prompt is not None: + doc_to_choice = self.prompt + elif doc_to_choice is not None: + doc_to_choice = doc_to_choice + elif self.config.doc_to_choice is None: + eval_logger.error("doc_to_choice was called but not set in config") + else: + doc_to_choice = self.config.doc_to_choice + + if isinstance(doc_to_choice, str): + if doc_to_choice in self.features: + return doc[doc_to_choice] + else: + return ast.literal_eval(utils.apply_template(doc_to_choice, doc)) + elif isinstance(doc_to_choice, list): + return doc_to_choice + elif isinstance(doc_to_choice, dict): + return list(doc_to_choice.values()) + elif callable(doc_to_choice): + return doc_to_choice(doc) + elif hasattr(doc_to_choice, "get_answer_choices_list"): + return doc_to_choice.get_answer_choices_list(doc) + else: + raise TypeError + + def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]: + if doc_to_image is not None: + doc_to_image = doc_to_image + elif 
self.config.doc_to_image is not None: + doc_to_image = self.config.doc_to_image + else: + return None + + if isinstance(doc_to_image, list): + image_feature = [ + self.doc_to_image(doc, feature) for feature in doc_to_image + ] + return [feature for feature in image_feature if feature is not None] + elif isinstance(doc_to_image, str): + if doc_to_image in self.features: + return doc[doc_to_image] + else: + return ast.literal_eval(utils.apply_template(doc_to_image, doc)) + elif callable(doc_to_image): + return doc_to_image(doc) + else: + return None + + def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]: + if doc_to_audio is not None: + doc_to_audio = doc_to_audio + elif self.config.doc_to_audio is not None: + doc_to_audio = self.config.doc_to_audio + else: + return None + + if isinstance(doc_to_audio, list): + audio_feature = [ + self.doc_to_audio(doc, feature) for feature in doc_to_audio + ] + return [feature for feature in audio_feature if feature is not None] + elif isinstance(doc_to_audio, str): + if doc_to_audio in self.features: + return doc[doc_to_audio] + else: + return ast.literal_eval(utils.apply_template(doc_to_audio, doc)) + elif callable(doc_to_audio): + return doc_to_audio(doc) + else: + return None + + def doc_to_prefix(self, doc): + if (gen_prefix := self.config.gen_prefix) is not None: + if gen_prefix in self.features: + return doc[gen_prefix] + else: + return utils.apply_template(gen_prefix, doc) + return None + + def construct_requests( + self, doc: dict, ctx: str, **kwargs + ) -> Union[List[Instance], Instance]: + apply_chat_template = kwargs.pop("apply_chat_template", False) + chat_template: Callable | None = kwargs.pop("chat_template", None) + + aux_arguments = None + + if self.OUTPUT_TYPE == "loglikelihood": + arguments = (ctx, self.doc_to_target(doc)) + elif self.OUTPUT_TYPE == "loglikelihood_rolling": + arguments = (self.doc_to_target(doc),) + elif self.OUTPUT_TYPE == "multiple_choice": + choices = self.doc_to_choice(doc) + target_delimiter = self.config.target_delimiter + if apply_chat_template: + target_delimiter = "" + if self.multiple_input: + # If there are multiple inputs, choices are placed in the ctx + # apply chat_template to choices if apply_chat_template + cont = self.doc_to_target(doc) + + arguments = [ + ( + ctx + + ( + chat_template([{"role": "user", "content": choice}]) + if apply_chat_template + else choice + ), + f"{target_delimiter}{cont}", + ) + for choice in choices + ] + else: + # Otherwise they are placed in the continuation + arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices] + + # TODO: we should raise a warning telling users this will at most ~2x runtime. + if "acc_mutual_info" in self._metric_fn_list.keys(): + # if we are calculating multiple choice accuracy + # using mutual information instead of raw loglikelihood as metric, need unconditional lls. + + # here mutual info refers to calculating + # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice)) + # in other words normalizing by subtracting the unconditional logprob of each choice. + # TODO: should these be strided? 
will have to modify the processing in process_results if so + aux_arguments = [ + ("", f"{target_delimiter}{choice}") for choice in choices + ] + + arguments.extend(aux_arguments) + + elif self.OUTPUT_TYPE == "generate_until": + arguments = (ctx, deepcopy(self.config.generation_kwargs)) + + multimodal_arg = {} + if ( + self.config.doc_to_image + ): # TODO: ensure that non-multimodal tasks aren't getting visual args + multimodal_arg = { + **multimodal_arg, + **{"visual": self.doc_to_image(doc)}, + } + + if ( + self.config.doc_to_audio + ): # TODO: ensure that non-multimodal tasks aren't getting audio args + multimodal_arg = { + **multimodal_arg, + **{"audio": self.doc_to_audio(doc)}, + } + + if bool(multimodal_arg): + if isinstance(arguments, list): + arguments = [arg + (multimodal_arg,) for arg in arguments] + else: + arguments = arguments + (multimodal_arg,) + + if self.OUTPUT_TYPE == "multiple_choice": + request_list = [ + Instance( + request_type="loglikelihood", + doc=doc, + arguments=arg, + idx=i, + **kwargs, + ) + for i, arg in enumerate(arguments) + ] + + return request_list + + return Instance( + request_type=self.OUTPUT_TYPE, + doc=doc, + arguments=arguments, + idx=0, + **kwargs, + ) + + def process_results(self, doc, results): + if callable(self.config.process_results): + return self.config.process_results(doc, results) + + result_dict = {} + use_metric = list(self._metric_fn_list.keys()) + if self.OUTPUT_TYPE == "loglikelihood": + results = results[0] + ll, is_greedy = results + return { + **({"perplexity": ll} if "perplexity" in use_metric else {}), + **({"acc": int(is_greedy)} if "acc" in use_metric else {}), + } + elif self.OUTPUT_TYPE == "loglikelihood_rolling": + (loglikelihood,) = results + _words = self.count_words(self.doc_to_target(doc)) + _bytes = self.count_bytes(self.doc_to_target(doc)) + return { + **( + {"word_perplexity": (loglikelihood, _words)} + if "word_perplexity" in use_metric + else {} + ), + **( + {"byte_perplexity": (loglikelihood, _bytes)} + if "byte_perplexity" in use_metric + else {} + ), + **( + {"bits_per_byte": (loglikelihood, _bytes)} + if "bits_per_byte" in use_metric + else {} + ), + } + elif self.OUTPUT_TYPE == "multiple_choice": + lls, is_greedy = zip(*results) + + # retrieve choices in List[str] form, to compute choice lengths, etc. + choices = self.doc_to_choice(doc) + completion_len = np.array([float(len(i)) for i in choices]) + + if ( + 2 * len(choices) == len(lls) + and "acc_mutual_info" in self._metric_fn_list.keys() + ): + # then we are doing mutual info. 
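# Note (illustrative, not from the diff): when `acc_mutual_info` is requested,
# construct_requests doubled the argument list with unconditional ("", choice)
# pairs, so `lls` arrives here with 2 * len(choices) loglikelihoods. The first
# half are log P(choice | ctx), the second half log P(choice), and the metric
# below scores argmax_i [ log P(choice_i | ctx) - log P(choice_i) ] against gold.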
+ # this stores the "dryrun" / unconditional answer loglikelihoods + # as we extend the args list with unconditional ("", continuation) pairs + lls_unconditional = lls[len(choices) :] + if len(lls_unconditional) != len(choices): + raise ValueError + # and this stores our "regular" conditional loglikelihoods + lls = lls[: len(choices)] + + pred = np.argmax(lls) + pred_norm = np.argmax(lls / completion_len) + + if self.multiple_input: + gold = self.doc_to_text(doc) + else: + gold = self.doc_to_target(doc) + + gold_index_error = False + if isinstance(gold, list): + gold = [i if i < len(choices) else -100 for i in gold] + if -100 in gold: + gold_index_error = True + else: + if isinstance(gold, int): + gold = gold if gold < len(choices) else -100 + elif isinstance(gold, str): + gold = choices.index(gold) if gold in choices else -100 + + if gold == -100: + gold_index_error = True + + if gold_index_error: + eval_logger.warning( + f"Label index was not in within range of available choices," + f"Sample:\n\n{doc}\n\n" + ) + + if self.multiple_target: + acc = 1.0 if pred in gold else 0.0 + acc_norm = 1.0 if pred_norm in gold else 0.0 + exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold])) + else: + acc = 1.0 if pred == gold else 0.0 + acc_norm = 1.0 if pred_norm == gold else 0.0 + # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly + exact_match = int(is_greedy[gold]) if gold != -100 else 0 + + prob_norm = utils.softmax(lls) + + # TODO use keyword arguments to the metric? + # gold, pred, norm stuff, the original lls, + result_dict = { + **({"acc": acc} if "acc" in use_metric else {}), + **({"f1": (gold, pred)} if "f1" in use_metric else {}), + **({"mcc": (gold, pred)} if "mcc" in use_metric else {}), + **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}), + **({"exact_match": exact_match} if "exact_match" in use_metric else {}), + **( + {"brier_score": (gold, prob_norm)} + if "brier_score" in use_metric + else {} + ), + } + + if "acc_mutual_info" in use_metric: + lls_mutual_info = [ + ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional) + ] + acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0 + result_dict["acc_mutual_info"] = acc_mutual_info + + elif self.OUTPUT_TYPE == "generate_until": + gold = self.doc_to_target(doc) + result = results[0] + if self.config.doc_to_choice is not None: + # If you set doc_to_choice, + # it assumes that doc_to_target returns a number. + choices = self.doc_to_choice(doc) + gold = choices[gold] + # we expect multiple_targets to be a list. 
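# Worked example (not from the diff) of the length normalisation above. With
# choices ["dog", "alligator"] and loglikelihoods lls = [-4.0, -6.0]:
#
#   pred      = argmax(lls)                  -> 0  ("dog")
#   pred_norm = argmax(lls / [3.0, 9.0])     -> 1  (-4/3 ~ -1.33 vs -6/9 ~ -0.67)
#
# so a longer choice can win acc_norm even when it loses on raw loglikelihood.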
+ elif self.multiple_target: + gold = list(gold) + # TODO: handle this better + elif type(gold) is not type(result) and not ( + "bypass" in self._metric_fn_list.keys() or isinstance(result, list) + ): + # cast gold to the same type as result + gold = type(result)(gold) + + for metric in self._metric_fn_list.keys(): + if self.multiple_target: + # in the case where we have multiple targets, + # return true if any are true + # TODO: this may break for multipLe_target, non zero-or-1 metrics + scores = [] + if not isinstance(gold, list): + # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer + # print(gold) + gold = [gold] + if metric == "exact_match": + result = [result for _ in range(len(gold))] + scores = self._metric_fn_list[metric]( + references=gold, + predictions=result, + **self._metric_fn_kwargs[metric], + )[metric] + result_score = 1.0 if scores > 0.0 else 0.0 + else: + for gold_option in gold: + try: + result_score = self._metric_fn_list[metric]( + references=[gold_option], + predictions=[result], + **self._metric_fn_kwargs[metric], + ) + except ( + TypeError + ): # TODO: this is hacky and I don't want to do it + result_score = self._metric_fn_list[metric]( + [gold_option, result] + ) + if isinstance(result_score, dict): + # TODO: this handles the case where HF evaluate returns a dict. + result_score = result_score[metric] + scores.append(result_score) + if any(scores): + result_score = 1.0 + else: + result_score = 0.0 + else: + try: + result_score = self._metric_fn_list[metric]( + references=[gold], + predictions=[result], + **self._metric_fn_kwargs[metric], + ) + except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics + result_score = self._metric_fn_list[metric]([gold, result]) + if isinstance(result_score, dict): + # TODO: this handles the case where HF evaluate returns a dict. + # This allows for multiple metrics to be returned from the same function + for k, v in result_score.items(): + result_dict[k] = v + else: + result_dict[metric] = result_score + else: + raise ValueError( + f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ", + "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'", + ) + + return result_dict + + def aggregation(self) -> dict: + return self._aggregation_list + + def higher_is_better(self) -> dict: + return self._higher_is_better + + def get_config(self, key: str) -> Any: + return getattr(self._config, key, None) + + @property + def task_name(self) -> Any: + return getattr(self.config, "task", None) + + def __repr__(self): + return ( + f"ConfigurableTask(task_name={getattr(self.config, 'task', None)}," + f"output_type={self.OUTPUT_TYPE}," + f"num_fewshot={getattr(self.config, 'num_fewshot', None)}," + f"num_samples={len(self.eval_docs)})" + ) + + +class MultipleChoiceTask(Task): + OUTPUT_TYPE = "loglikelihood" + + def doc_to_target(self, doc: dict) -> str: + return " " + doc["choices"][doc["gold"]] + + def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]: + # TODO: add mutual info here? + return [ + Instance( + request_type="loglikelihood", + doc=doc, + arguments=(ctx, " {}".format(choice)), + idx=i, + **kwargs, + ) + for i, choice in enumerate(doc["choices"]) + ] + + def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict: + results = [ + res[0] for res in results + ] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere? 
+ gold = doc["gold"] + + acc = 1.0 if np.argmax(results) == gold else 0.0 + completion_len = np.array([float(len(i)) for i in doc["choices"]]) + acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0 + + return { + "acc": acc, + "acc_norm": acc_norm, + } + + def higher_is_better(self) -> dict: + return { + "acc": True, + "acc_norm": True, + } + + def aggregation(self) -> dict: + return { + "acc": mean, + "acc_norm": mean, + } + + +class PerplexityTask(Task): + OUTPUT_TYPE = "loglikelihood_rolling" + + def has_training_docs(self) -> bool: + return False + + def fewshot_examples(self, k: int, rnd) -> List: + if k != 0: + raise ValueError( + "The number of fewshot examples must be 0 for perplexity tasks." + ) + return [] + + def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]: + if num_fewshot != 0: + raise ValueError( + "The number of fewshot examples must be 0 for perplexity tasks." + ) + + return "" + + def higher_is_better(self) -> dict: + return { + "word_perplexity": False, + "byte_perplexity": False, + "bits_per_byte": False, + } + + def doc_to_decontamination_query(self, doc): + return doc + + def doc_to_text(self, doc) -> str: + return "" + + def doc_to_target(self, doc): + return doc + + def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs): + if bool(ctx): + raise ValueError + + return Instance( + request_type=self.OUTPUT_TYPE, + doc=doc, + arguments=(self.doc_to_target(doc),), + idx=0, + **kwargs, + ) + + def process_results(self, doc: dict, results: Tuple[float]) -> dict: + (loglikelihood,) = results + words = self.count_words(self.doc_to_target(doc)) + bytes_ = self.count_bytes(self.doc_to_target(doc)) + return { + "word_perplexity": (loglikelihood, words), + "byte_perplexity": (loglikelihood, bytes_), + "bits_per_byte": (loglikelihood, bytes_), + } + + def aggregation(self) -> dict: + return { + "word_perplexity": weighted_perplexity, + "byte_perplexity": weighted_perplexity, + "bits_per_byte": bits_per_byte, + } + + @classmethod + def count_bytes(cls, doc) -> int: + return len(doc.encode("utf-8")) + + @classmethod + def count_words(cls, doc) -> int: + """Downstream tasks with custom word boundaries should override this!""" + return len(re.split(r"\s+", doc)) diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/caching/__init__.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/caching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/caching/cache.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/caching/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..f8d293b0ff8b1ebac186f5ac078cdb49227562db --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/caching/cache.py @@ -0,0 +1,59 @@ +import hashlib +import logging +import os + +import dill + + +eval_logger = logging.getLogger(__name__) + + +MODULE_DIR = os.path.dirname(os.path.realpath(__file__)) + +OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH") + + +PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache" + +# This should be sufficient for uniqueness +HASH_INPUT = "EleutherAI-lm-evaluation-harness" + +HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest() + +FILE_SUFFIX = f".{HASH_PREFIX}.pickle" + + +def load_from_cache(file_name: str, cache: bool = False): + if not cache: + return + try: + path = f"{PATH}/{file_name}{FILE_SUFFIX}" + + with open(path, "rb") as file: + cached_task_dict = dill.loads(file.read()) + return 
cached_task_dict + + except Exception: + eval_logger.debug(f"{file_name} is not cached, generating...") + pass + + +def save_to_cache(file_name, obj): + if not os.path.exists(PATH): + os.mkdir(PATH) + + file_path = f"{PATH}/{file_name}{FILE_SUFFIX}" + + eval_logger.debug(f"Saving {file_path} to cache...") + with open(file_path, "wb") as file: + file.write(dill.dumps(obj)) + + +# NOTE the "key" param is to allow for flexibility +def delete_cache(key: str = ""): + files = os.listdir(PATH) + + for file in files: + if file.startswith(key) and file.endswith(FILE_SUFFIX): + file_path = f"{PATH}/{file}" + os.unlink(file_path) diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/decontamination/__init__.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/decontamination/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/decontamination/janitor.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/decontamination/janitor.py new file mode 100644 index 0000000000000000000000000000000000000000..cedf8a5717aa8156674836ba236fdcabf36e0487 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/decontamination/janitor.py @@ -0,0 +1,328 @@ +import pickle +import re +import string +import traceback +from typing import Iterator, List, Sequence, Tuple, TypeVar + + +# This is a cpp module. Compile janitor_util.cpp with: +# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup +try: + import janitor_util + + JANITOR_CPP = True +except Exception: + print("WARNING: C++ module could not be loaded. Janitor running in python mode") + traceback.print_exc() + JANITOR_CPP = False + +T = TypeVar("T") + + +# Implementation from nltk source +# https://www.nltk.org/_modules/nltk/util.html +def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]: + history = [] + while n > 1: + # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator + try: + next_item = next(sequence) + except StopIteration: + # no more data, terminate the generator + return + history.append(next_item) + n -= 1 + for item in sequence: + history.append(item) + yield tuple(history) + del history[0] + + +def word_ngrams(s: str, n: int) -> Iterator[str]: + """Splits a string into ngram words""" + tokens = s.split() # not a generator :( + ngram_seqs = form_ngrams(iter(tokens), n) + return (" ".join(ngram) for ngram in ngram_seqs) + + +# Does character sequences only - combined faster function to play around with later +# def word_ngrams_indices_combined(sequence, n): +# current_word = "" +# history = [] +# gap = False; +# start = 0 +# end = 0 +# for character in sequence: +# if character == " ": +# if not gap: +# gap = True +# history.append(current_word) +# end += len(current_word) - 1 +# current_word = "" +# if len(history) == n: +# yield (tuple(history), start, end) +# del history[0] +# start = end + 1 +# end = start +# else: +# gap = False +# current_word += character + + +# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python +def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]: + """Splits a string on whitespaces and records the indices of each in the original string. + @:return generator((word, (start_idx, end_idx)), ...) 
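# Minimal usage sketch (not from the diff) for the request cache defined in
# caching/cache.py above. Keys map to "<PATH>/<file_name><FILE_SUFFIX>" on
# disk; load_from_cache returns None when caching is disabled or the file is
# missing. The key below is a hypothetical example of the form built in
# task.build_all_requests.
from dllm_eval.caching.cache import delete_cache, load_from_cache, save_to_cache

key = "requests-my_task-0shot-rank0-world_size1"
save_to_cache(file_name=key, obj=[["instance-0"], ["instance-1"]])
cached = load_from_cache(file_name=key, cache=True)   # -> the pickled object
delete_cache(key)                                      # removes matching entries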
+ """ + return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s)) + + +def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]: + """Splits a string into pairs of (ngram words, their start/end indices)""" + tokens_with_indices = split_indices(s) + + # Generator of ngrams of (word, idx_pairs) + # ( + # [(word, (start,end)), (word, (start, end))...], + # [(word, (start, end)), ...], + # ... + # ) + ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n) + + # Generator of pairs of word and index ngrams + # ( + # ([word, word, ...], [(start,end), (start,end), ...]), + # ... + # ) + ngram_indices_pairs = ( + zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices + ) + + # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...) + return ( + (" ".join(ngram_seq), (indices[0][0], indices[-1][1])) + for ngram_seq, indices in ngram_indices_pairs + ) + + +class Janitor: + # FIXME delete_chars: Should anything else go here? Special chars? + def __init__( + self, + ngram_n: int = 13, + window_to_remove: int = 200, + too_dirty_cutoff: int = 10, + minimum_slice_length: int = 200, + delete_chars: str = string.punctuation, + ) -> None: + self.ngram_n = ngram_n + self.window_to_remove = window_to_remove + self.too_dirty_cutoff = too_dirty_cutoff + self.minimum_slice_length = minimum_slice_length + self.delete_chars = delete_chars + + self.dirt_ngrams = set() + + # If in python, we'll translate uppercase to lowercase and delete naughty characters. + # This is fast by python standards + # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st + self.translation_table = str.maketrans( + string.ascii_lowercase + string.ascii_uppercase, # These characters + string.ascii_lowercase * 2, # Become these characters + self.delete_chars, # These are deleted + ) + + ############## + # I/O for saving contamination ngrams + ############## + + def save_contamination_ngrams(self, filename: str) -> None: + with open(filename, "wb") as fp: + pickle.dump(filename, fp) + + def load_contamination_ngrams(self, filename: str) -> None: + with open(filename, "rb") as fp: + self.dirt_ngrams = pickle.load(fp) + + ############## + # Call these :) + ############## + + def register_contaminant(self, dirt_string: str) -> None: + """Register a string as contamination to be removed, e.g. a test set + This breaks the dirt_string into ngrams to store for future cleaning""" + if JANITOR_CPP: + return self.register_contaminant_cpp(dirt_string) + else: + print("WARNING: Janitor running in python mode") + return self.register_contaminant_python(dirt_string) + + def clean(self, dirty_string: str) -> List[str]: + """Clean a string (e.g. a training set) by removing all ngrams previously + registered as contaminants. 
Returns a list of clean chunks, or empty if + the string was too dirty""" + if JANITOR_CPP: + return self.clean_cpp(dirty_string) + else: + print("WARNING: Janitor running in python mode") + return self.clean_python(dirty_string) + + def _split_chunks( + self, dirty_string: str, dirty_parts: Sequence[Tuple] + ) -> List[str]: + clean_chunks = [] + splice_idx = 0 + end = -1 + for i, (ngram, start, end) in enumerate(dirty_parts): + if i >= self.too_dirty_cutoff: + return [] + start = max(0, start - self.window_to_remove) + end = min(len(dirty_string), end + self.window_to_remove) + + if start - splice_idx > self.minimum_slice_length: + clean_chunks.append(dirty_string[splice_idx:start]) + splice_idx = end + + if end < len(dirty_string) - self.minimum_slice_length: + clean_chunks.append(dirty_string[end + 1 :]) + + return clean_chunks + + ############## + # Fast C++ + ############## + + def register_contaminant_cpp(self, dirt_string) -> None: + self.dirt_ngrams.update( + janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n) + ) + + def clean_cpp(self, dirty_string: str) -> List[str]: + contamination_indices = janitor_util.clean_ngram_with_indices( + dirty_string, self.delete_chars, self.ngram_n + ) + return self._split_chunks(dirty_string, contamination_indices) + + ############## + # Slow python + ############## + + def normalize_string(self, s: str) -> str: + return s.translate(self.translation_table) + + def register_contaminant_python(self, dirt_string: str) -> None: + self.dirt_ngrams.update( + word_ngrams(self.normalize_string(dirt_string), self.ngram_n) + ) + + def clean_python(self, dirty_string: str) -> List[str]: + contamination_indices = ( + (None, *idx_pair) + for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n) + if self.normalize_string(dirty_ngram) in self.dirt_ngrams + ) + return self._split_chunks(dirty_string, contamination_indices) + + +################################################################## +# Tests +################################################################# + +# def print_cpp(): +# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2 + +# for i in range(1, 10, 2): +# pprint(janitor_util.clean_ngram(source, string.punctuation, i)) +# for ngram, start, end in \ +# janitor_util.clean_ngram_with_indices(source, string.punctuation, i): +# print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n")) + + +# def test_cpp(): +# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2 +# contaminant = "dirty boy. Clean he he" + +# jan_python = Janitor() +# jan_cpp = Janitor() + +# jan_python.register_contaminant_python(contaminant) +# jan_cpp.register_contaminant(contaminant) + +# assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams) + +# assert jan_python.clean_python(source) == jan_cpp.clean(source), \ +# (jan_python.clean_python(source), jan_cpp.clean(source)) + +# print("Passed test, python==cpp") + + +# def benchmark(): +# # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html +# setup = \ +# """ +# with open("data/enwik8", "r") as f: +# data = f.read() +# jan = Janitor(too_dirty_cutoff=1000) +# jan.register_contaminant(''' +# theories is that there is a connection between "geekdom" and autism. 
+# This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled " +# The [[Geek]] Syndrome", which is a point argued by many in the autism rights +# movement{{ref|Wired}}. This article, many professionals assert, is just one example of +# the media's application of mental disease labels to what is actually variant normal behavior +# &mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual +# interests, even when they seem unusual to others, are not in themselves signs of autism or +# Asperger's syndrome. Others assert that it is actually the medical profession which is applying +# mental disease labels to children who in the past would have simply been accepted as a little +# different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue. +# Due to the recent publicity surrounding autism and autis +# ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first, +# oil money had a marginal impact. A few lowrise concete buildings were erected, and the first +# paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties +# would last, took a cautious approach, preferring to save the revenue rather than investing it in +# development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential +# to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his +# brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]], +# with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M, +# ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995), +# ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the +# Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the +# [[United Arab Emirates]]. After the Emirates gained independence in 1971, +# ''') +# """ + +# n = 1 +# print(f"Timing {n} run on 100 MB") +# print("Register contaminant") +# # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n)) +# print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n)) + +# print("Clean") +# # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n)) +# print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n)) + + +# def test_janitor_general(): +# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2 +# contaminant = "dirty boy. 
Clean he he" + +# jan = Janitor(ngram_n=3) +# jan.register_contaminant(contaminant) +# cleaned = " ".join(jan.clean(source)) +# for contam in jan.dirt_ngrams: +# assert contam not in cleaned, contam + +# filename = "data/saved_contam" +# jan.save_contamination_ngrams(filename) + +# jan = Janitor(ngram_n=3) +# jan.load_contamination_ngrams(filename) +# cleaned = " ".join(jan.clean(source)) +# for contam in jan.dirt_ngrams: +# assert contam not in cleaned, contam + + +# if __name__ == "__main__": +# test() +# # print_cpp() +# # test_cpp() +# # benchmark() diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/__init__.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02b7a6834c6486fde35ef02d715e90be3fba223a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/__init__.py @@ -0,0 +1,2 @@ +from .evaluation_tracker import EvaluationTracker +from .wandb_logger import WandbLogger diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/evaluation_tracker.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/evaluation_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..7f88978e73a8fad88d83a9563e85090b8c7e5594 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/evaluation_tracker.py @@ -0,0 +1,530 @@ +import json +import logging +import os +import re +import time +from collections import defaultdict +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path + +from datasets import load_dataset +from datasets.utils.metadata import MetadataConfigs +from huggingface_hub import ( + DatasetCard, + DatasetCardData, + HfApi, + hf_hub_url, +) +from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status + +from dllm_eval.utils import ( + get_file_datetime, + get_file_task_name, + get_results_filenames, + get_sample_results_filenames, + handle_non_serializable, + hash_string, + sanitize_list, + sanitize_model_name, + sanitize_task_name, +) + + +eval_logger = logging.getLogger(__name__) + + +@dataclass(init=False) +class GeneralConfigTracker: + """ + Tracker for the evaluation parameters. + + Attributes: + model_source (str): Source of the model (e.g. Hugging Face, GGUF, etc.) + model_name (str): Name of the model. + model_name_sanitized (str): Sanitized model name for directory creation. + start_time (float): Start time of the experiment. Logged at class init. + end_time (float): Start time of the experiment. Logged when calling [`GeneralConfigTracker.log_end_time`] + total_evaluation_time_seconds (str): Inferred total evaluation time in seconds (from the start and end times). + """ + + model_source: str = None + model_name: str = None + model_name_sanitized: str = None + system_instruction: str = None + system_instruction_sha: str = None + fewshot_as_multiturn: bool = None + chat_template: str = None + chat_template_sha: str = None + start_time: float = None + end_time: float = None + total_evaluation_time_seconds: str = None + + def __init__(self) -> None: + """Starts the evaluation timer.""" + self.start_time = time.perf_counter() + + @staticmethod + def _get_model_name(model_args: str) -> str: + """Extracts the model name from the model arguments.""" + + def extract_model_name(model_args: str, key: str) -> str: + """Extracts the model name from the model arguments using a key.""" + args_after_key = model_args.split(key)[1] + return args_after_key.split(",")[0] + + # order does matter, e.g. 
peft and delta are provided together with pretrained + prefixes = ["peft=", "delta=", "pretrained=", "model=", "path=", "engine="] + for prefix in prefixes: + if prefix in model_args: + return extract_model_name(model_args, prefix) + return "" + + def log_experiment_args( + self, + model_source: str, + model_args: str, + system_instruction: str, + chat_template: str, + fewshot_as_multiturn: bool, + ) -> None: + """Logs model parameters and job ID.""" + self.model_source = model_source + self.model_name = GeneralConfigTracker._get_model_name(model_args) + self.model_name_sanitized = sanitize_model_name(self.model_name) + self.system_instruction = system_instruction + self.system_instruction_sha = ( + hash_string(system_instruction) if system_instruction else None + ) + self.chat_template = chat_template + self.chat_template_sha = hash_string(chat_template) if chat_template else None + self.fewshot_as_multiturn = fewshot_as_multiturn + + def log_end_time(self) -> None: + """Logs the end time of the evaluation and calculates the total evaluation time.""" + self.end_time = time.perf_counter() + self.total_evaluation_time_seconds = str(self.end_time - self.start_time) + + +class EvaluationTracker: + """ + Keeps track and saves relevant information of the evaluation process. + Compiles the data from trackers and writes it to files, which can be published to the Hugging Face hub if requested. + """ + + def __init__( + self, + output_path: str = None, + hub_results_org: str = "", + hub_repo_name: str = "", + details_repo_name: str = "", + results_repo_name: str = "", + push_results_to_hub: bool = False, + push_samples_to_hub: bool = False, + public_repo: bool = False, + token: str = "", + leaderboard_url: str = "", + point_of_contact: str = "", + gated: bool = False, + ) -> None: + """ + Creates all the necessary loggers for evaluation tracking. + + Args: + output_path (str): Path to save the results. If not provided, the results won't be saved. + hub_results_org (str): The Hugging Face organization to push the results to. If not provided, the results will be pushed to the owner of the Hugging Face token. + hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`. + details_repo_name (str): The name of the Hugging Face repository to push the details to. If not provided, the results will be pushed to `lm-eval-results`. + result_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will not be pushed and will be found in the details_hub_repo. + push_results_to_hub (bool): Whether to push the results to the Hugging Face hub. + push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub. + public_repo (bool): Whether to push the results to a public or private repository. + token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`. + leaderboard_url (str): URL to the leaderboard on the Hugging Face hub on the dataset card. + point_of_contact (str): Contact information on the Hugging Face hub dataset card. + gated (bool): Whether to gate the repository. 
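A hypothetical local-only setup of the tracker (model name and output path are illustrative; nothing is pushed to the Hub):

tracker = EvaluationTracker(output_path="./eval_out")
tracker.general_config_tracker.log_experiment_args(
    model_source="hf",
    model_args="pretrained=GSAI-ML/LLaDA-8B-Base,dtype=bfloat16",
    system_instruction=None,
    chat_template=None,
    fewshot_as_multiturn=False,
)
# _get_model_name takes the value after the first matching prefix:
# "pretrained=GSAI-ML/LLaDA-8B-Base,dtype=bfloat16" -> "GSAI-ML/LLaDA-8B-Base"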
+ """ + self.general_config_tracker = GeneralConfigTracker() + + self.output_path = output_path + self.push_results_to_hub = push_results_to_hub + self.push_samples_to_hub = push_samples_to_hub + self.public_repo = public_repo + self.leaderboard_url = leaderboard_url + self.point_of_contact = point_of_contact + self.api = HfApi(token=token) if token else None + self.gated_repo = gated + + if not self.api and (push_results_to_hub or push_samples_to_hub): + raise ValueError( + "Hugging Face token is not defined, but 'push_results_to_hub' or 'push_samples_to_hub' is set to True. " + "Please provide a valid Hugging Face token by setting the HF_TOKEN environment variable." + ) + + if ( + self.api + and hub_results_org == "" + and (push_results_to_hub or push_samples_to_hub) + ): + hub_results_org = self.api.whoami()["name"] + eval_logger.warning( + f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'." + ) + + if hub_repo_name == "": + details_repo_name = ( + details_repo_name if details_repo_name != "" else "lm-eval-results" + ) + results_repo_name = ( + results_repo_name if results_repo_name != "" else details_repo_name + ) + else: + details_repo_name = hub_repo_name + results_repo_name = hub_repo_name + eval_logger.warning( + "hub_repo_name was specified. Both details and results will be pushed to the same repository. Using hub_repo_name is no longer recommended, details_repo_name and results_repo_name should be used instead." + ) + + self.details_repo = f"{hub_results_org}/{details_repo_name}" + self.details_repo_private = f"{hub_results_org}/{details_repo_name}-private" + self.results_repo = f"{hub_results_org}/{results_repo_name}" + self.results_repo_private = f"{hub_results_org}/{results_repo_name}-private" + + def save_results_aggregated( + self, + results: dict, + samples: dict, + ) -> None: + """ + Saves the aggregated results and samples to the output path and pushes them to the Hugging Face hub if requested. + + Args: + results (dict): The aggregated results to save. + samples (dict): The samples results to save. 
+ """ + self.general_config_tracker.log_end_time() + + if self.output_path: + try: + eval_logger.info("Saving results aggregated") + + # calculate cumulative hash for each task - only if samples are provided + task_hashes = {} + if samples: + for task_name, task_samples in samples.items(): + sample_hashes = [ + s["doc_hash"] + s["prompt_hash"] + s["target_hash"] + for s in task_samples + ] + task_hashes[task_name] = hash_string("".join(sample_hashes)) + + # update initial results dict + results.update({"task_hashes": task_hashes}) + results.update(asdict(self.general_config_tracker)) + dumped = json.dumps( + results, + indent=2, + default=handle_non_serializable, + ensure_ascii=False, + ) + + path = Path(self.output_path if self.output_path else Path.cwd()) + self.date_id = datetime.now().isoformat().replace(":", "-") + if path.suffix == ".json": + path.parent.mkdir(parents=True, exist_ok=True) + file_results_aggregated = path.with_name( + f"{path.stem}_{self.date_id}.json" + ) + else: + path.mkdir(parents=True, exist_ok=True) + file_results_aggregated = path.joinpath( + f"results_{self.date_id}.json" + ) + + file_results_aggregated.open("w", encoding="utf-8").write(dumped) + + if self.api and self.push_results_to_hub: + repo_id = ( + self.results_repo + if self.public_repo + else self.results_repo_private + ) + self.api.create_repo( + repo_id=repo_id, + repo_type="dataset", + private=not self.public_repo, + exist_ok=True, + ) + self.api.upload_file( + repo_id=repo_id, + path_or_fileobj=str(file_results_aggregated), + path_in_repo=os.path.join( + self.general_config_tracker.model_name, + file_results_aggregated.name, + ), + repo_type="dataset", + commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}", + ) + eval_logger.info( + "Successfully pushed aggregated results to the Hugging Face Hub. " + f"You can find them at: {repo_id}" + ) + + except Exception as e: + eval_logger.warning("Could not save results aggregated") + eval_logger.info(repr(e)) + else: + eval_logger.info( + "Output path not provided, skipping saving results aggregated" + ) + + def save_results_samples( + self, + task_name: str, + samples: dict, + ) -> None: + """ + Saves the samples results to the output path and pushes them to the Hugging Face hub if requested. + + Args: + task_name (str): The task name to save the samples for. + samples (dict): The samples results to save. 
+ """ + if self.output_path: + try: + eval_logger.info(f"Saving per-sample results for: {task_name}") + + path = Path(self.output_path if self.output_path else Path.cwd()) + if path.suffix == ".json": + path = path.parent + path.mkdir(parents=True, exist_ok=True) + + file_results_samples = path.joinpath( + f"samples_{task_name}_{self.date_id}.jsonl" + ) + + for sample in samples: + # we first need to sanitize arguments and resps + # otherwise we won't be able to load the dataset + # using the datasets library + arguments = {} + for i, arg in enumerate(sample["arguments"]): + arguments[f"gen_args_{i}"] = {} + for j, tmp in enumerate(arg): + arguments[f"gen_args_{i}"][f"arg_{j}"] = tmp + + sample["resps"] = sanitize_list(sample["resps"]) + sample["filtered_resps"] = sanitize_list(sample["filtered_resps"]) + sample["arguments"] = arguments + sample["target"] = str(sample["target"]) + + sample_dump = ( + json.dumps( + sample, + default=handle_non_serializable, + ensure_ascii=False, + ) + + "\n" + ) + + with open(file_results_samples, "a", encoding="utf-8") as f: + f.write(sample_dump) + + if self.api and self.push_samples_to_hub: + repo_id = ( + self.details_repo + if self.public_repo + else self.details_repo_private + ) + self.api.create_repo( + repo_id=repo_id, + repo_type="dataset", + private=not self.public_repo, + exist_ok=True, + ) + try: + if self.gated_repo: + headers = build_hf_headers() + r = get_session().put( + url=f"https://huggingface.co/api/datasets/{repo_id}/settings", + headers=headers, + json={"gated": "auto"}, + ) + hf_raise_for_status(r) + except Exception as e: + eval_logger.warning("Could not gate the repository") + eval_logger.info(repr(e)) + self.api.upload_folder( + repo_id=repo_id, + folder_path=str(path), + path_in_repo=self.general_config_tracker.model_name_sanitized, + repo_type="dataset", + commit_message=f"Adding samples results for {task_name} to {self.general_config_tracker.model_name}", + ) + eval_logger.info( + f"Successfully pushed sample results for task: {task_name} to the Hugging Face Hub. " + f"You can find them at: {repo_id}" + ) + + except Exception as e: + eval_logger.warning("Could not save sample results") + eval_logger.info(repr(e)) + else: + eval_logger.info("Output path not provided, skipping saving sample results") + + def recreate_metadata_card(self) -> None: + """ + Creates a metadata card for the evaluation results dataset and pushes it to the Hugging Face hub. + """ + + eval_logger.info("Recreating metadata card") + repo_id = self.details_repo if self.public_repo else self.details_repo_private + + files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset") + results_files = get_results_filenames(files_in_repo) + sample_files = get_sample_results_filenames(files_in_repo) + + # Build a dictionary to store the latest evaluation datetime for: + # - Each tested model and its aggregated results + # - Each task and sample results, if existing + # i.e. 
{ + # "org__model_name__gsm8k": "2021-09-01T12:00:00", + # "org__model_name__ifeval": "2021-09-01T12:00:00", + # "org__model_name__results": "2021-09-01T12:00:00" + # } + latest_task_results_datetime = defaultdict(lambda: datetime.min.isoformat()) + + for file_path in sample_files: + file_path = Path(file_path) + filename = file_path.name + model_name = file_path.parent + task_name = get_file_task_name(filename) + results_datetime = get_file_datetime(filename) + task_name_sanitized = sanitize_task_name(task_name) + # Results and sample results for the same model and task will have the same datetime + samples_key = f"{model_name}__{task_name_sanitized}" + results_key = f"{model_name}__results" + latest_datetime = max( + latest_task_results_datetime[samples_key], + results_datetime, + ) + latest_task_results_datetime[samples_key] = latest_datetime + latest_task_results_datetime[results_key] = max( + latest_task_results_datetime[results_key], + latest_datetime, + ) + + # Create metadata card + card_metadata = MetadataConfigs() + + # Add the latest aggregated results to the metadata card for easy access + for file_path in results_files: + file_path = Path(file_path) + results_filename = file_path.name + model_name = file_path.parent + eval_date = get_file_datetime(results_filename) + eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date) + results_filename = Path("**") / Path(results_filename).name + config_name = f"{model_name}__results" + sanitized_last_eval_date_results = re.sub( + r"[^\w\.]", "_", latest_task_results_datetime[config_name] + ) + + if eval_date_sanitized == sanitized_last_eval_date_results: + # Ensure that all results files are listed in the metadata card + current_results = card_metadata.get(config_name, {"data_files": []}) + current_results["data_files"].append( + {"split": eval_date_sanitized, "path": [str(results_filename)]} + ) + card_metadata[config_name] = current_results + # If the results file is the newest, update the "latest" field in the metadata card + card_metadata[config_name]["data_files"].append( + {"split": "latest", "path": [str(results_filename)]} + ) + + # Add the tasks details configs + for file_path in sample_files: + file_path = Path(file_path) + filename = file_path.name + model_name = file_path.parent + task_name = get_file_task_name(filename) + eval_date = get_file_datetime(filename) + task_name_sanitized = sanitize_task_name(task_name) + eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date) + results_filename = Path("**") / Path(filename).name + config_name = f"{model_name}__{task_name_sanitized}" + sanitized_last_eval_date_results = re.sub( + r"[^\w\.]", "_", latest_task_results_datetime[config_name] + ) + if eval_date_sanitized == sanitized_last_eval_date_results: + # Ensure that all sample results files are listed in the metadata card + current_details_for_task = card_metadata.get( + config_name, {"data_files": []} + ) + current_details_for_task["data_files"].append( + {"split": eval_date_sanitized, "path": [str(results_filename)]} + ) + card_metadata[config_name] = current_details_for_task + # If the samples results file is the newest, update the "latest" field in the metadata card + card_metadata[config_name]["data_files"].append( + {"split": "latest", "path": [str(results_filename)]} + ) + + # Get latest results and extract info to update metadata card examples + latest_datetime = max(latest_task_results_datetime.values()) + latest_model_name = max( + latest_task_results_datetime, key=lambda k: latest_task_results_datetime[k] + ) + 
last_results_file = [ + f for f in results_files if latest_datetime.replace(":", "-") in f + ][0] + last_results_file_path = hf_hub_url( + repo_id=repo_id, filename=last_results_file, repo_type="dataset" + ) + latest_results_file = load_dataset( + "json", data_files=last_results_file_path, split="train" + ) + results_dict = latest_results_file["results"][0] + new_dictionary = {"all": results_dict} + new_dictionary.update(results_dict) + results_string = json.dumps(new_dictionary, indent=4) + + dataset_summary = ( + "Dataset automatically created during the evaluation run of model " + ) + if self.general_config_tracker.model_source == "hf": + dataset_summary += f"[{self.general_config_tracker.model_name}](https://huggingface.co/{self.general_config_tracker.model_name})\n" + else: + dataset_summary += f"{self.general_config_tracker.model_name}\n" + dataset_summary += ( + f"The dataset is composed of {len(card_metadata) - 1} configuration(s), each one corresponding to one of the evaluated task.\n\n" + f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each " + 'configuration, the split being named using the timestamp of the run.The "train" split is always pointing to the latest results.\n\n' + 'An additional configuration "results" store all the aggregated results of the run.\n\n' + "To load the details from a run, you can for instance do the following:\n" + ) + if self.general_config_tracker.model_source == "hf": + dataset_summary += ( + "```python\nfrom datasets import load_dataset\n" + f'data = load_dataset(\n\t"{repo_id}",\n\tname="{latest_model_name}",\n\tsplit="latest"\n)\n```\n\n' + ) + dataset_summary += ( + "## Latest results\n\n" + f"These are the [latest results from run {latest_datetime}]({last_results_file_path.replace('/resolve/', '/blob/')}) " + "(note that there might be results for other tasks in the repos if successive evals didn't cover the same tasks. " + 'You find each in the results and the "latest" split for each eval):\n\n' + f"```python\n{results_string}\n```" + ) + card_data = DatasetCardData( + dataset_summary=dataset_summary, + repo_url=f"https://huggingface.co/{self.general_config_tracker.model_name}", + pretty_name=f"Evaluation run of {self.general_config_tracker.model_name}", + leaderboard_url=self.leaderboard_url, + point_of_contact=self.point_of_contact, + ) + card_metadata.to_dataset_card_data(card_data) + card = DatasetCard.from_template( + card_data, + pretty_name=card_data.pretty_name, + ) + card.push_to_hub(repo_id, repo_type="dataset") diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/utils.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ba795edb72d7b665a2c0fe6d4f3e3a5ed91b6940 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/utils.py @@ -0,0 +1,149 @@ +import logging +import os +import re +import subprocess +from importlib.metadata import version +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +from torch.utils.collect_env import get_pretty_env_info +from transformers import __version__ as trans_version + + +logger = logging.getLogger(__name__) + + +def remove_none_pattern(input_string: str) -> Tuple[str, bool]: + """Remove the ',none' substring from the input_string if it exists at the end. + + Args: + input_string (str): The input string from which to remove the ',none' substring. 
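A couple of worked cases for this helper:

remove_none_pattern("acc,none")          # -> ("acc", True)
remove_none_pattern("acc,strict-match")  # -> ("acc,strict-match", False)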
+ + Returns: + Tuple[str, bool]: A tuple containing the modified input_string with the ',none' substring removed + and a boolean indicating whether the modification was made (True) or not (False). + """ + # Define the pattern to match ',none' at the end of the string + pattern = re.compile(r",none$") + + # Use sub() to replace ',none' with an empty string + result = re.sub(pattern, "", input_string) + + # check if the input_string changed + removed = result != input_string + + return result, removed + + +def _handle_non_serializable(o: Any) -> Union[int, str, list]: + """Handle non-serializable objects by converting them to serializable types. + + Args: + o (Any): The object to be handled. + + Returns: + Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32, + it will be converted to int. If the object is of type set, it will be converted + to a list. Otherwise, it will be converted to str. + """ + if isinstance(o, np.int64) or isinstance(o, np.int32): + return int(o) + elif isinstance(o, set): + return list(o) + else: + return str(o) + + +def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]: + try: + git_folder = Path(repo_path, ".git") + if git_folder.is_file(): + git_folder = Path( + git_folder.parent, + git_folder.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1], + ) + if Path(git_folder, "HEAD").exists(): + head_name = ( + Path(git_folder, "HEAD") + .read_text(encoding="utf-8") + .split("\n")[0] + .split(" ")[-1] + ) + head_ref = Path(git_folder, head_name) + git_hash = head_ref.read_text(encoding="utf-8").replace("\n", "") + else: + git_hash = None + except Exception as err: + logger.debug( + f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. Error: {err}" + ) + return None + return git_hash + + +def get_git_commit_hash(): + """ + Gets the git commit hash of your current repo (if it exists). 
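A small usage sketch of _handle_non_serializable as a json.dumps fallback; the payload values are illustrative:

import json
import numpy as np

payload = {"count": np.int64(3), "tags": {"a", "b"}, "fn": len}
json.dumps(payload, default=_handle_non_serializable)
# numpy integers -> int, sets -> list, anything else (here the builtin len) -> str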
+ Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42 + """ + try: + git_hash = subprocess.check_output(["git", "describe", "--always"]).strip() + git_hash = git_hash.decode() + except (subprocess.CalledProcessError, FileNotFoundError): + # FileNotFoundError occurs when git not installed on system + git_hash = get_commit_from_path(os.getcwd()) # git hash of repo if exists + return git_hash + + +def add_env_info(storage: Dict[str, Any]): + try: + pretty_env_info = get_pretty_env_info() + except Exception as err: + pretty_env_info = str(err) + try: + dllm_eval_version = version("dllm_eval") + except Exception as err: + dllm_eval_version = str(err) + transformers_version = trans_version + upper_dir_commit = get_commit_from_path( + Path(os.getcwd(), "..") + ) # git hash of upper repo if exists + added_info = { + "pretty_env_info": pretty_env_info, + "transformers_version": transformers_version, + "dllm_eval_version": dllm_eval_version, + "upper_git_hash": upper_dir_commit, # in case this repo is submodule + } + storage.update(added_info) + + +def add_tokenizer_info(storage: Dict[str, Any], lm): + if getattr(lm, "tokenizer", False): + try: + tokenizer_info = { + "tokenizer_pad_token": [ + lm.tokenizer.pad_token, + str(lm.tokenizer.pad_token_id), + ], + "tokenizer_eos_token": [ + lm.tokenizer.eos_token, + str(lm.tokenizer.eos_token_id), + ], + "tokenizer_bos_token": [ + lm.tokenizer.bos_token, + str(lm.tokenizer.bos_token_id), + ], + "eot_token_id": getattr(lm, "eot_token_id", None), + "max_length": getattr(lm, "max_length", None), + } + storage.update(tokenizer_info) + except Exception as err: + logger.debug( + f"Logging detailed tokenizer info failed with {err}, skipping..." + ) + # seems gguf and textsynth do not have tokenizer + else: + logger.debug( + "LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results." + ) diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/wandb_logger.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/wandb_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..9c0859b3c8e90437f21b6f06143b14941a7a96d2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/loggers/wandb_logger.py @@ -0,0 +1,358 @@ +import copy +import json +import logging +from typing import Any, Dict, List, Literal, Tuple + +import numpy as np +import pandas as pd +from packaging.version import Version + +from dllm_eval.loggers.utils import _handle_non_serializable, remove_none_pattern + + +logger = logging.getLogger(__name__) + + +def get_wandb_printer() -> Literal["Printer"]: + """Returns a wandb printer instance for pretty stdout.""" + from wandb.sdk.lib.printer import new_printer + + printer = new_printer() + return printer + + +class WandbLogger: + def __init__(self, init_args=None, config_args=None) -> None: + """Attaches to wandb logger if already initialized. Otherwise, passes init_args to wandb.init() and config_args to wandb.config.update() + + Args: + init_args Optional[Dict]: Arguments for init configuration. 
+ config_args Optional[Dict]: Arguments for config + + Parse and log the results returned from evaluator.simple_evaluate() with: + wandb_logger.post_init(results) + wandb_logger.log_eval_result() + wandb_logger.log_eval_samples(results["samples"]) + """ + try: + import wandb + + assert Version(wandb.__version__) >= Version("0.13.6") + if Version(wandb.__version__) < Version("0.13.6"): + wandb.require("report-editing:v0") + except Exception as e: + logger.warning( + "To use the wandb reporting functionality please install wandb>=0.13.6.\n" + "To install the latest version of wandb run `pip install wandb --upgrade`\n" + f"{e}" + ) + + self.wandb_args: Dict[str, Any] = init_args or {} + self.wandb_config_args: Dict[str, Any] = config_args or {} + + # pop the step key from the args to save for all logging calls + self.step = self.wandb_args.pop("step", None) + + # initialize a W&B run + if wandb.run is None: + self.run = wandb.init(**self.wandb_args) + if self.wandb_config_args: + self.run.config.update(self.wandb_config_args) + else: + self.run = wandb.run + + self.printer = get_wandb_printer() + + def post_init(self, results: Dict[str, Any]) -> None: + self.results: Dict[str, Any] = copy.deepcopy(results) + self.task_names: List[str] = list(results.get("results", {}).keys()) + self.group_names: List[str] = list(results.get("groups", {}).keys()) + + def _get_config(self) -> Dict[str, Any]: + """Get configuration parameters.""" + self.task_configs = self.results.get("configs", {}) + cli_configs = self.results.get("config", {}) + configs = { + "task_configs": self.task_configs, + "cli_configs": cli_configs, + } + + return configs + + def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]: + """Sanitize the results dictionary.""" + _results = copy.deepcopy(self.results.get("results", dict())) + + # Remove None from the metric string name + tmp_results = copy.deepcopy(_results) + for task_name in self.task_names: + task_result = tmp_results.get(task_name, dict()) + for metric_name, metric_value in task_result.items(): + _metric_name, removed = remove_none_pattern(metric_name) + if removed: + _results[task_name][_metric_name] = metric_value + _results[task_name].pop(metric_name) + + # remove string valued keys from the results dict + wandb_summary = {} + for task in self.task_names: + task_result = _results.get(task, dict()) + for metric_name, metric_value in task_result.items(): + if isinstance(metric_value, str): + wandb_summary[f"{task}/{metric_name}"] = metric_value + + for summary_metric, summary_value in wandb_summary.items(): + _task, _summary_metric = summary_metric.split("/") + _results[_task].pop(_summary_metric) + + tmp_results = copy.deepcopy(_results) + for task_name, task_results in tmp_results.items(): + for metric_name, metric_value in task_results.items(): + _results[f"{task_name}/{metric_name}"] = metric_value + _results[task_name].pop(metric_name) + for task in self.task_names: + _results.pop(task) + + return wandb_summary, _results + + def _log_results_as_table(self) -> None: + """Generate and log evaluation results as a table to W&B.""" + columns = [ + "Version", + "Filter", + "num_fewshot", + "Metric", + "Value", + "Stderr", + ] + + def make_table(columns: List[str], key: str = "results"): + import wandb + + table = wandb.Table(columns=columns) + results = copy.deepcopy(self.results) + + for k, dic in results.get(key).items(): + if k in self.group_names and not key == "groups": + continue + version = results.get("versions").get(k) + if version == "N/A": + 
version = None + n = results.get("n-shot").get(k) + + for (mf), v in dic.items(): + m, _, f = mf.partition(",") + if m.endswith("_stderr"): + continue + if m == "alias": + continue + + if m + "_stderr" + "," + f in dic: + se = dic[m + "_stderr" + "," + f] + if se != "N/A": + se = "%.4f" % se + table.add_data(*[k, version, f, n, m, str(v), str(se)]) + else: + table.add_data(*[k, version, f, n, m, str(v), ""]) + + return table + + # log the complete eval result to W&B Table + table = make_table(["Tasks"] + columns, "results") + self.run.log({"evaluation/eval_results": table}, step=self.step) + + if "groups" in self.results.keys(): + table = make_table(["Groups"] + columns, "groups") + self.run.log({"evaluation/group_eval_results": table}, step=self.step) + + def _log_results_as_artifact(self) -> None: + """Log results as JSON artifact to W&B.""" + import wandb + + dumped = json.dumps( + self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False + ) + artifact = wandb.Artifact("results", type="eval_results") + with artifact.new_file("results.json", mode="w", encoding="utf-8") as f: + f.write(dumped) + self.run.log_artifact(artifact) + + def log_eval_result(self) -> None: + """Log evaluation results to W&B.""" + # Log configs to wandb + configs = self._get_config() + self.run.config.update(configs, allow_val_change=self.step is not None) + + wandb_summary, self.wandb_results = self._sanitize_results_dict() + # update wandb.run.summary with items that were removed + self.run.summary.update(wandb_summary) + # Log the evaluation metrics to wandb + self.run.log(self.wandb_results, step=self.step) + # Log the evaluation metrics as W&B Table + self._log_results_as_table() + # Log the results dict as json to W&B Artifacts + self._log_results_as_artifact() + + def _generate_dataset( + self, data: List[Dict[str, Any]], config: Dict[str, Any] + ) -> pd.DataFrame: + """Generate a dataset from evaluation data. + + Args: + data (List[Dict[str, Any]]): The data to generate a dataset for. + config (Dict[str, Any]): The configuration of the task. + + Returns: + pd.DataFrame: A dataframe that is ready to be uploaded to W&B. 
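Looking back at log_eval_result above, a small illustration of how _sanitize_results_dict reshapes the harness results before they are logged (the metric values are made up):

# results["results"] = {"hellaswag": {"alias": "hellaswag", "acc,none": 0.57, "acc_stderr,none": 0.01}}
# wandb_summary -> {"hellaswag/alias": "hellaswag"}                        string-valued entries go to run.summary
# wandb_results -> {"hellaswag/acc": 0.57, "hellaswag/acc_stderr": 0.01}   ",none" stripped, keys flattened to task/metric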
+ """ + ids = [x["doc_id"] for x in data] + labels = [x["target"] for x in data] + instance = [""] * len(ids) + resps = [""] * len(ids) + filtered_resps = [""] * len(ids) + model_outputs = {} + + metrics_list = config["metric_list"] + metrics = {} + for metric in metrics_list: + metric = metric.get("metric") + if metric in ["word_perplexity", "byte_perplexity", "bits_per_byte"]: + metrics[f"{metric}_loglikelihood"] = [x[metric][0] for x in data] + if metric in ["byte_perplexity", "bits_per_byte"]: + metrics[f"{metric}_bytes"] = [x[metric][1] for x in data] + else: + metrics[f"{metric}_words"] = [x[metric][1] for x in data] + else: + metrics[metric] = [x[metric] for x in data] + + if config["output_type"] == "loglikelihood": + instance = [x["arguments"][0][0] for x in data] + labels = [x["arguments"][0][1] for x in data] + resps = [ + f"log probability of continuation is {x['resps'][0][0][0]} " + + "\n\n" + + "continuation will {} generated with greedy sampling".format( + "not be" if not x["resps"][0][0][1] else "be" + ) + for x in data + ] + filtered_resps = [ + f"log probability of continuation is {x['filtered_resps'][0][0]} " + + "\n\n" + + "continuation will {} generated with greedy sampling".format( + "not be" if not x["filtered_resps"][0][1] else "be" + ) + for x in data + ] + elif config["output_type"] == "multiple_choice": + instance = [x["arguments"][0][0] for x in data] + choices = [ + "\n".join([f"{idx}. {y[1]}" for idx, y in enumerate(x["arguments"])]) + for x in data + ] + resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data] + filtered_resps = [ + np.argmax([n[0] for n in x["filtered_resps"]]) for x in data + ] + elif config["output_type"] == "loglikelihood_rolling": + instance = [x["arguments"][0][0] for x in data] + resps = [x["resps"][0][0] for x in data] + filtered_resps = [x["filtered_resps"][0] for x in data] + elif config["output_type"] == "generate_until": + instance = [x["arguments"][0][0] for x in data] + resps = [x["resps"][0][0] for x in data] + filtered_resps = [x["filtered_resps"][0] for x in data] + + model_outputs["raw_predictions"] = resps + model_outputs["filtered_predictions"] = filtered_resps + + df_data = { + "id": ids, + "data": instance, + } + if config["output_type"] == "multiple_choice": + df_data["choices"] = choices + + tmp_data = { + "input_len": [len(x) for x in instance], + "labels": labels, + "output_type": config["output_type"], + } + df_data.update(tmp_data) + df_data.update(model_outputs) + df_data.update(metrics) + + return pd.DataFrame(df_data) + + def _log_samples_as_artifact( + self, data: List[Dict[str, Any]], task_name: str + ) -> None: + import wandb + + # log the samples as an artifact + dumped = json.dumps( + data, + indent=2, + default=_handle_non_serializable, + ensure_ascii=False, + ) + artifact = wandb.Artifact(f"{task_name}", type="samples_by_task") + with artifact.new_file( + f"{task_name}_eval_samples.json", mode="w", encoding="utf-8" + ) as f: + f.write(dumped) + self.run.log_artifact(artifact) + # artifact.wait() + + def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None: + """Log evaluation samples to W&B. + + Args: + samples (Dict[str, List[Dict[str, Any]]]): Evaluation samples for each task. 
+ """ + task_names: List[str] = [ + x for x in self.task_names if x not in self.group_names + ] + + ungrouped_tasks = [] + tasks_by_groups = {} + + for task_name in task_names: + group_names = self.task_configs[task_name].get("group", None) + if group_names: + if isinstance(group_names, str): + group_names = [group_names] + + for group_name in group_names: + if not tasks_by_groups.get(group_name): + tasks_by_groups[group_name] = [task_name] + else: + tasks_by_groups[group_name].append(task_name) + else: + ungrouped_tasks.append(task_name) + + for task_name in ungrouped_tasks: + eval_preds = samples[task_name] + + # log the samples as a W&B Table + df = self._generate_dataset(eval_preds, self.task_configs.get(task_name)) + self.run.log({f"{task_name}_eval_results": df}, step=self.step) + + # log the samples as a json file as W&B Artifact + self._log_samples_as_artifact(eval_preds, task_name) + + for group, grouped_tasks in tasks_by_groups.items(): + grouped_df = pd.DataFrame() + for task_name in grouped_tasks: + eval_preds = samples[task_name] + df = self._generate_dataset( + eval_preds, self.task_configs.get(task_name) + ) + df["group"] = group + df["task"] = task_name + grouped_df = pd.concat([grouped_df, df], ignore_index=True) + + # log the samples as a json file as W&B Artifact + self._log_samples_as_artifact(eval_preds, task_name) + + self.run.log({f"{group}_eval_results": grouped_df}, step=self.step) diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/LLaDA.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/LLaDA.py new file mode 100644 index 0000000000000000000000000000000000000000..bfe0b51b21f2ab3a584af362635e38036d415a36 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/LLaDA.py @@ -0,0 +1,786 @@ +import logging +import os +from datetime import timedelta +from typing import Dict, List, Literal, Optional, Tuple, Union, TypeVar +import torch +import torch.nn.functional as F +import numpy as np +import transformers +import json +from accelerate import ( + Accelerator, + InitProcessGroupKwargs, +) +from datasets import Dataset +from accelerate.utils import get_max_memory +from packaging import version +from tqdm import tqdm +import torch.distributed as dist +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, +) +from dllm_eval.api.instance import Instance +from dllm_eval.api.model import LM, TemplateLM +from dllm_eval.api.registry import register_model +from dllm_eval.models.utils import get_dtype, configure_pad_token + +try: + from .hts_sampler import HTSSampler +except ImportError: + HTSSampler = None + +eval_logger = logging.getLogger(__name__) +T = TypeVar("T", bound="LM") + + +def add_gumbel_noise(logits, temperature): + """Add Gumbel noise for sampling""" + if temperature == 0.0: + return logits + logits = logits.to(torch.float32) + noise = torch.rand_like(logits, dtype=torch.float32) + gumbel_noise = (-torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + +def get_num_transfer_tokens(mask_index, steps): + """Calculate number of tokens to transfer at each step""" + mask_num = mask_index.sum(dim=1, keepdim=True) + base = mask_num // steps + remainder = mask_num % steps + num_transfer_tokens = base.expand(-1, steps).clone() + if remainder.sum() > 0: + indices = torch.arange(steps, device=mask_index.device) + mask = indices.unsqueeze(0) < remainder + num_transfer_tokens[mask] += 1 + return num_transfer_tokens.to(torch.int64) + + +@torch.no_grad() +def 
generate_llada_v1(model, prompt, attention_mask=None, steps=128, gen_length=128, + block_length=128, temperature=0., cfg_scale=0., + remasking='low_confidence', mask_id=126336, + logits_eos_inf=False, confidence_eos_eot_inf=False): + """ + LLaDA v1 generation function + This is the original generate function from LLaDA v1 + """ + x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, + dtype=torch.long).to(model.device) + x[:, :prompt.shape[1]] = prompt.clone() + + if attention_mask is not None: + attention_mask = torch.cat([ + attention_mask, + torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, + device=model.device) + ], dim=-1) + + prompt_index = (x != mask_id) + + assert gen_length % block_length == 0 + num_blocks = gen_length // block_length + + assert steps % num_blocks == 0 + steps_per_block = steps // num_blocks + + for num_block in range(num_blocks): + block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: + prompt.shape[1] + (num_block + 1) * block_length] == mask_id) + num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block) + + for i in range(steps_per_block): + mask_index = (x == mask_id) + + if cfg_scale > 0.: + un_x = x.clone() + un_x[prompt_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + if attention_mask is not None: + attention_mask_ = torch.cat([attention_mask, attention_mask], dim=0) + logits = model(x_, attention_mask=attention_mask_).logits + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = model(x, attention_mask=attention_mask).logits + + if logits_eos_inf: + logits[:, :, 126081] = -torch.inf + + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) + + if confidence_eos_eot_inf: + logits_with_noise[:, :, 126081] = logits[:, :, 126348] = -torch.inf + + if remasking == 'low_confidence': + p = F.softmax(logits, dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) + elif remasking == 'random': + x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) + else: + raise NotImplementedError(remasking) + + x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf + + x0 = torch.where(mask_index, x0, x) + confidence = torch.where(mask_index, x0_p, -np.inf) + + transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) + for j in range(confidence.shape[0]): + _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i]) + transfer_index[j, select_index] = True + + x[transfer_index] = x0[transfer_index] + + return x + + +@register_model("LLaDA") +class LLaDA(TemplateLM): + AUTO_MODEL_CLASS = transformers.AutoModel + _DEFAULT_MAX_LENGTH = 20480 + + def __init__( + self, + pretrained: Union[str, transformers.PreTrainedModel], + backend: Literal["default", "causal", "seq2seq"] = "causal", + revision: Optional[str] = "main", + subfolder: Optional[str] = None, + tokenizer: Optional[ + Union[ + str, + transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast, + ] + ] = None, + truncation: Optional[bool] = False, + logits_cache: bool = True, + max_length: Optional[int] = None, + device: Optional[str] = "cuda", + dtype: Optional[Union[str, torch.dtype]] = "auto", + batch_size: Optional[Union[int]] = 1, + max_batch_size: Optional[int] = 64, + trust_remote_code: Optional[bool] = True, + use_fast_tokenizer: Optional[bool] = True, + add_bos_token: Optional[bool] = 
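A toy check of the unmasking schedule used by generate_llada_v1, with illustrative numbers and run in this module's scope: with gen_length=128, block_length=32 and steps=64 there are 4 blocks decoded left to right, 16 denoising steps per block, and each step commits 2 of the block's 32 masked positions.

mask_index = torch.ones(1, 32, dtype=torch.bool)       # a fresh block: all 32 positions masked
get_num_transfer_tokens(mask_index, steps=16)           # -> tensor([[2, 2, ..., 2]]), 16 entries summing to 32
get_num_transfer_tokens(mask_index[:, :30], steps=16)   # -> 14 steps of 2 followed by 2 steps of 1 (remainder spread first)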
False, + escape_until: Optional[bool] = False, + prefix_token_id: Optional[int] = None, + parallelize: Optional[bool] = False, + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = "./offload", + peft: Optional[str] = None, + delta: Optional[str] = None, + autogptq: Optional[Union[bool, str]] = False, + gptqmodel: Optional[bool] = False, + gguf_file: Optional[str] = None, + mc_num: int = 1024, + remasking: str = "low_confidence", + mask_id: int = 126336, # LLaDA v1 default mask_id + is_check_greedy: bool = True, + assistant_prefix: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__() + self.mc_num = mc_num + self.mask_id = mask_id + self.remasking = remasking + self.pretrained = pretrained + self.is_check_greedy = is_check_greedy + self.assistant_prefix = assistant_prefix + self.add_bos_token = add_bos_token + self.escape_until = escape_until + + if not isinstance(pretrained, str): + eval_logger.warning( + "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored." + ) + assert not parallelize, ( + "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`" + ) + self._model = pretrained + self._device = self._model.device + self._config = self._model.config + gpus = 0 + else: + assert isinstance(device, str) + assert isinstance(pretrained, str) + assert isinstance(batch_size, (int, str)) + gpus = torch.cuda.device_count() + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + if accelerator.num_processes > 1: + self.accelerator = accelerator + if "npu" in accelerator.device.type: + gpus = torch.npu.device_count() + if not (parallelize or accelerator.num_processes > 1): + device_list = set( + ["cuda", "cpu"] + + [f"cuda:{i}" for i in range(gpus)] + + ["mps", "mps:0"] + + [f"npu:{i}" for i in range(gpus)] + ) + if device and device in device_list: + self._device = torch.device(device) + eval_logger.info(f"Using device '{device}'") + if device in ("mps", "mps:0") and version.parse( + torch.__version__ + ) < version.parse("2.1"): + raise RuntimeError( + f"mps requires torch >= 2.1. You have {torch.__version__}" + ) + else: + eval_logger.info("Device not specified") + eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}") + self._device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) + else: + if device != "cuda": + eval_logger.info( + f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model." 
+ ) + self._device = ( + self.accelerator.device + if hasattr(self, "accelerator") + else torch.device(device) + ) + revision = str(revision) + revision = revision + ("/" + subfolder if subfolder is not None else "") + self._get_config( + pretrained, + revision=revision, + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + ) + + self._get_backend( + config=self.config, backend=backend, trust_remote_code=trust_remote_code + ) + self._create_tokenizer( + pretrained, + tokenizer, + revision=revision, + trust_remote_code=trust_remote_code, + use_fast_tokenizer=use_fast_tokenizer, + gguf_file=gguf_file, + add_bos_token=add_bos_token, + ) + + if isinstance(pretrained, str): + self._create_model( + pretrained=pretrained, + revision=revision, + dtype=dtype, + trust_remote_code=trust_remote_code, + parallelize=parallelize, + gpus=gpus, + max_memory_per_gpu=max_memory_per_gpu, + max_cpu_memory=max_cpu_memory, + offload_folder=offload_folder, + peft=peft, + delta=delta, + autogptq=autogptq, + gptqmodel=gptqmodel, + gguf_file=gguf_file, + **kwargs, + ) + + if isinstance(self.model, torch.nn.Module): + self.model.eval() + self.model.tie_weights() + + self.truncation = truncation + self.logits_cache = logits_cache + self.vocab_size = self.tokenizer.vocab_size + self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config) + self.add_bos_token = add_bos_token + + if "gemma" in getattr(self.config, "model_type", ""): + self.add_bos_token = True + eval_logger.info( + f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used." + ) + + self._max_length = max_length + self.pretrained = pretrained + self.delta = delta + self.peft = peft + self.revision = revision + self.batch_schedule = 1 + self.batch_sizes = {} + self.max_batch_size = max_batch_size + + if str(batch_size).startswith("auto"): + batch_size = batch_size.split(":") + self.batch_size_per_gpu = batch_size[0] + self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1 + else: + self.batch_size_per_gpu = int(batch_size) + + if isinstance(pretrained, str): + if gpus >= 1 or str(self.device) == "mps": + if not (parallelize or autogptq or hasattr(self, "accelerator")): + try: + self.model.to(self.device) + except ValueError: + eval_logger.debug( + "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided." + ) + if gpus > 1: + if hasattr(self, "accelerator") and self.accelerator.num_processes > 1: + if parallelize: + eval_logger.warning( + "You are both using a HF Accelerate `device_map` and launching via `accelerate launch`." + ) + elif gpus > self.accelerator.num_processes: + eval_logger.warning( + "WARNING: The number of total system GPUs does not match the number of spawned processes." + ) + self._device = torch.device(f"{self.accelerator.device}") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self._rank = 0 + self._world_size = 1 + else: + self._rank = 0 + self._world_size = 1 + else: + eval_logger.warning( + "Passed an already-initialized model through `pretrained`, assuming single-process call." 
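The batch_size handling above accepts either an integer or an "auto" schedule string; a few worked cases (the downstream semantics of batch_schedule follow the upstream harness and are not shown in this diff):

# batch_size=8        -> batch_size_per_gpu = 8
# batch_size="auto"   -> batch_size_per_gpu = "auto", batch_schedule = 1
# batch_size="auto:4" -> batch_size_per_gpu = "auto", batch_schedule = 4.0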
+ ) + self._rank = 0 + self._world_size = 1 + + self.custom_prefix_token_id = prefix_token_id + if prefix_token_id is not None: + eval_logger.info( + f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}" + ) + self.is_first_inference = True + + if HTSSampler is not None: + self.hts_sampler = HTSSampler(self.model, self.tokenizer, device=self.device) + eval_logger.info("HTSSampler initialized successfully.") + + # Copy all the property and helper methods from LLaDA2 + @property + def rank(self): + if hasattr(self, "_rank"): + return self._rank + if hasattr(self, "accelerator"): + return self.accelerator.local_process_index + return int(os.environ.get("LOCAL_RANK", 0)) + + @property + def world_size(self): + if hasattr(self, "_world_size"): + return self._world_size + if hasattr(self, "accelerator"): + return self.accelerator.num_processes + return int(os.environ.get("WORLD_SIZE", 1)) + + def _get_accelerate_args( + self, + parallelize: Optional[bool] = None, + device_map: Optional[str] = "auto", + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[str] = "./offload", + gpus: Optional[int] = None, + ) -> dict: + """Get accelerate arguments - same as LLaDA2""" + num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) + if parallelize is None and gpus is not None and gpus > 1: + parallelize = True + args = {} + if parallelize: + max_memory_all_gpus = get_max_memory() + if "cpu" in max_memory_all_gpus: + del max_memory_all_gpus["cpu"] + max_memory_per_gpu_map = { + device_idx: max_memory_per_gpu for device_idx in range(len(max_memory_all_gpus)) + } if max_memory_per_gpu is not None else {k: v for k, v in max_memory_all_gpus.items()} + if hasattr(self, "accelerator"): + max_memory_per_gpu_map = { + k: v for k, v in max_memory_all_gpus.items() + if k % num_local_processes == self.accelerator.process_index % num_local_processes + } + args["max_memory"] = max_memory_per_gpu_map + args["device_map"] = "auto" + args["offload_folder"] = offload_folder + if max_cpu_memory is not None: + args["max_memory"]["cpu"] = max_cpu_memory + else: + args["device_map"] = {"": str(self.device)} + return args + + @property + def config(self): + return self._config + + @property + def model(self): + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + return self.tokenizer.eos_token_id + + @property + def prefix_token_id(self): + if self.custom_prefix_token_id is not None: + return self.custom_prefix_token_id + if self.tokenizer.bos_token_id is not None: + return self.tokenizer.bos_token_id + return self.tokenizer.eos_token_id + + @property + def max_length(self): + if self._max_length: + return self._max_length + seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") + for attr in seqlen_config_attrs: + if hasattr(self.model.config, attr): + return getattr(self.model.config, attr) + if hasattr(self.tokenizer, "model_max_length"): + if self.tokenizer.model_max_length > 1e10: + return self._DEFAULT_MAX_LENGTH + return self.tokenizer.model_max_length + return self._DEFAULT_MAX_LENGTH + + @property + def max_gen_toks(self) -> int: + return 256 + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def tokenizer_name(self) -> str: + return self.tokenizer.name_or_path.replace("/", "__") + + def 
_get_backend(self, config, backend, trust_remote_code): + """Get backend type - same as LLaDA2""" + assert backend in ["default", "causal", "seq2seq"] + if backend != "default": + self.backend = backend + eval_logger.info(f"Overrode HF model backend type, and using type '{self.backend}'") + else: + if getattr(config, "model_type") in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: + self.backend = "seq2seq" + elif getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + self.backend = "causal" + else: + eval_logger.warning("HF model type is neither CausalLM nor Seq2SeqLM. Assuming CausalLM.") + self.backend = "causal" + + def _get_config(self, pretrained, revision, trust_remote_code, gguf_file): + """Get model config - same as LLaDA2""" + self._config = transformers.AutoConfig.from_pretrained( + pretrained, revision=revision, trust_remote_code=trust_remote_code + ) + + def _create_model(self, pretrained, revision, dtype, trust_remote_code, parallelize, + gpus, max_memory_per_gpu, max_cpu_memory, offload_folder, + peft, delta, autogptq, gptqmodel, gguf_file, **kwargs): + """Create model - same as LLaDA2""" + if autogptq or gptqmodel: + raise NotImplementedError("Quantization options are not implemented.") + model_dtype = get_dtype(dtype) + eval_logger.info(f"Loading model with dtype: {model_dtype}") + model_kwargs = kwargs if kwargs else {} + if not parallelize: + model_kwargs.update( + self._get_accelerate_args( + parallelize=parallelize, + gpus=gpus, + max_memory_per_gpu=max_memory_per_gpu, + max_cpu_memory=max_cpu_memory, + offload_folder=offload_folder + ) + ) + self._model = transformers.AutoModelForCausalLM.from_pretrained( + pretrained, revision=revision, torch_dtype=model_dtype, + trust_remote_code=trust_remote_code, **model_kwargs + ) + if peft: + from peft import PeftModel + eval_logger.info(f"Loading PEFT model from {peft}") + self._model = PeftModel.from_pretrained(self._model, peft, torch_dtype=model_dtype) + if not parallelize: + self._model = self._model.to(self.device) + self._model = self._model.to(torch.bfloat16) + self._model.eval() + + def _create_tokenizer(self, pretrained, tokenizer, revision, trust_remote_code, + use_fast_tokenizer, gguf_file, add_bos_token): + """Create tokenizer - same as LLaDA2""" + kwargs = { + "revision": revision, + "trust_remote_code": trust_remote_code, + "use_fast": use_fast_tokenizer + } + if add_bos_token: + kwargs["add_bos_token"] = True + if tokenizer: + if isinstance(tokenizer, str): + self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer, **kwargs) + else: + self.tokenizer = tokenizer + else: + model_name = pretrained if isinstance(pretrained, str) else self.model.name_or_path + self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, **kwargs) + + def tok_encode(self, string, left_truncate_len=None, add_special_tokens=None): + """Tokenize string - same as LLaDA2""" + special_tokens_kwargs = {} + if add_special_tokens is None: + if self.backend == "causal": + special_tokens_kwargs["add_special_tokens"] = self.add_bos_token + else: + special_tokens_kwargs["add_special_tokens"] = add_special_tokens + encoding = self.tokenizer.encode(string, **special_tokens_kwargs) + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + return encoding + + def tok_batch_encode(self, strings, padding_side="left", left_truncate_len=None, truncation=False): + """Batch tokenize - same as LLaDA2""" + old_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = padding_side + 
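# Editor's sketch (not part of the patch): the left-truncation rule shared by
# tok_encode/tok_batch_encode keeps the *last* `left_truncate_len` tokens, so the text
# closest to the generation point survives when a prompt exceeds the context window.
token_ids = list(range(10))          # stand-in for an encoded prompt
left_truncate_len = 4
assert token_ids[-left_truncate_len:] == [6, 7, 8, 9]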
add_special_tokens = {"add_special_tokens": self.add_bos_token} if self.backend == "causal" else {} + encoding = self.tokenizer( + strings, truncation=truncation, padding="longest", + return_tensors="pt", **add_special_tokens + ) + if left_truncate_len and encoding["input_ids"].size(1) > left_truncate_len: + eval_logger.warning(f"Left-truncating from {encoding['input_ids'].size(1)} to {left_truncate_len} tokens.") + encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] + encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:] + self.tokenizer.padding_side = old_padding_side + return encoding["input_ids"].to(self.device), encoding["attention_mask"].to(self.device) + + def tok_decode(self, tokens, skip_special_tokens=False): + """Decode tokens - same as LLaDA2""" + return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def _model_call(self, inps, attn_mask=None, labels=None): + """Model forward call - same as LLaDA2""" + with torch.no_grad(): + if self.backend == "seq2seq": + return self.model(input_ids=inps, attention_mask=attn_mask, labels=labels).logits + else: + return self.model(inps, attention_mask=attn_mask).logits + + def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]: + raise NotImplementedError + + def loglikelihood_rolling( + self, requests: List[Instance], disable_tqdm: bool = False + ) -> List[float]: + raise NotImplementedError + + def loglikelihood(self, requests): + raise NotImplementedError + + def generate_until(self, requests: List[Instance]) -> List[str]: + """Generate until - adapted for LLaDA v1 """ + res = [] + gen_kwargs = requests[0].args[1] + use_hts = gen_kwargs.get("use_hts", False) + + realtime_output = gen_kwargs.get("realtime_output", "realtime_hts_results.jsonl") + baseline_realtime_output = gen_kwargs.get("realtime_output", "realtime_baseline_results.jsonl") + + if not use_hts and "realtime_output" not in gen_kwargs: + baseline_realtime_output = "realtime_baseline_results.jsonl" + + if not use_hts: + bar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Running Baseline (LLaDA v1)") + + for req in requests: + prompt_text = req.args[0] + local_gen_kwargs = req.args[1] if len(req.args) > 1 else {} + + context_enc, _ = self.tok_batch_encode([prompt_text]) + + final_codes, stats = self.hts_sampler.generate_hts( + prompt_text=prompt_text, + input_ids=context_enc, + initial_N=1, + final_K=1, + hts_survivor_k=1, + hts_mode=False, + hts_start_pct=0.0, + hts_end_pct=0.0, + decay_factor=1.5, + pruning_interval=0, + reward_mode="confidence", + task_type=local_gen_kwargs.get("task_type", "code"), + steps=int(local_gen_kwargs.get("steps", 32)), + gen_length=int(local_gen_kwargs.get("gen_length", 512)), + block_length=int(local_gen_kwargs.get("block_length", 32)), + temperature=float(local_gen_kwargs.get("temperature", 0.0)), + top_p=float(local_gen_kwargs.get("top_p", 0.95)), + top_k=local_gen_kwargs.get("top_k", None), + threshold=float(local_gen_kwargs.get("threshold", 0.85)), + mask_id=self.mask_id, + eos_id=self.eot_token_id, + until=local_gen_kwargs.get("until", []), + ) + + processed_codes = [] + for code in final_codes: + code = code.strip() + if not self.escape_until: + until_terms = local_gen_kwargs.get("until", []) + for term in until_terms: + if len(term) > 0 and term in code: + code = code.split(term)[0] + processed_codes.append(code) + + final_choice = processed_codes[0] if processed_codes else "" + res.append(final_choice) + + target_val = getattr(req, 
"target", None) + if target_val is None or target_val == "N/A": + if "test" in req.doc and "entry_point" in req.doc: + target_val = req.doc["test"] + "\ncheck(" + req.doc["entry_point"] + ")" + else: + target_val = req.doc.get("answer", req.doc.get("solution", "N/A")) + + output_dir = os.path.dirname(baseline_realtime_output) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + with open(baseline_realtime_output, "a", encoding="utf-8") as f: + all_resps = [[code] for code in processed_codes] + output_data = { + "doc": req.doc, + "target": target_val, + "resps": all_resps, + "prompt": prompt_text, + "entropy_history": stats.get("entropy_history", []), + "pruning_history": stats.get("pruning_history", []), + "final_scores": stats.get("final_scores", []), + "all_trajectories": stats.get("all_trajectories", []), + "nfe": stats.get("nfe", 0), + "first_block_nfe": stats.get("first_block_nfe", 0), + "svf_calls": stats.get("svf_calls", 0), + "total_steps": stats.get("total_steps", 0), + "num_gen_blocks": stats.get("num_gen_blocks", []), + "steps_per_block": stats.get("steps_per_block", []) + } + f.write(json.dumps(output_data, ensure_ascii=False) + "\n") + f.flush() + + bar.update(1) + bar.close() + + else: + bar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Running HTS+SVF (LLaDA v1)") + for req in requests: + prompt_text = req.args[0] + local_gen_kwargs = req.args[1] if len(req.args) > 1 else {} + context_enc, _ = self.tok_batch_encode([prompt_text]) + + p_interval = int(local_gen_kwargs.get("pruning_interval", 0)) + + final_codes, stats = self.hts_sampler.generate_hts( + prompt_text=prompt_text, + input_ids=context_enc, + initial_N=int(local_gen_kwargs.get("hts_N", 4)), + final_K=int(local_gen_kwargs.get("final_K", 1)), + hts_survivor_k=int(local_gen_kwargs.get("hts_survivor_k", 4)), + hts_mode=local_gen_kwargs.get("hts_mode", True), + hts_start_pct=float(local_gen_kwargs.get("hts_start_pct", 0.1)), + hts_end_pct=float(local_gen_kwargs.get("hts_end_pct", 0.6)), + decay_factor=float(local_gen_kwargs.get("decay_factor", 1.5)), + pruning_interval=p_interval, + reward_mode=local_gen_kwargs.get("reward_mode", "svf"), + task_type=local_gen_kwargs.get("task_type", "code"), + steps=int(local_gen_kwargs.get("steps", 32)), + gen_length=int(local_gen_kwargs.get("gen_length", 512)), + block_length=int(local_gen_kwargs.get("block_length", 32)), + temperature=float(local_gen_kwargs.get("temperature", 0.7)), + top_p=float(local_gen_kwargs.get("top_p", 0.95)), + top_k=local_gen_kwargs.get("top_k", None), + threshold=float(local_gen_kwargs.get("threshold", 0.85)), + mask_id=self.mask_id, + eos_id=self.eot_token_id, + until=local_gen_kwargs.get("until", []), + ) + + processed_codes = [] + for code in final_codes: + code = code.strip() + if not self.escape_until: + until_terms = local_gen_kwargs.get("until", []) + for term in until_terms: + if len(term) > 0 and term in code: + code = code.split(term)[0] + processed_codes.append(code) + + final_choice = processed_codes[0] + res.append(final_choice) + + target_val = getattr(req, "target", None) + if target_val is None or target_val == "N/A": + if "test" in req.doc and "entry_point" in req.doc: + target_val = req.doc["test"] + "\ncheck(" + req.doc["entry_point"] + ")" + else: + target_val = req.doc.get("answer", req.doc.get("solution", "N/A")) + + output_dir = os.path.dirname(realtime_output) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + with open(realtime_output, "a", encoding="utf-8") as f: + all_resps = [[code] for code in 
processed_codes] + output_data = { + "doc": req.doc, + "target": target_val, + "resps": all_resps, + "prompt": prompt_text, + "entropy_history": stats.get("entropy_history", []), + "pruning_history": stats.get("pruning_history", []), + "final_scores": stats.get("final_scores", []), + "all_trajectories": stats.get("all_trajectories", []), + "nfe": stats.get("nfe", 0), + "first_block_nfe": stats.get("first_block_nfe", 0), + "svf_calls": stats.get("svf_calls", 0), + "total_steps": stats.get("total_steps", 0), + "num_gen_blocks": stats.get("num_gen_blocks", []), + "steps_per_block": stats.get("steps_per_block", []) + } + f.write(json.dumps(output_data, ensure_ascii=False) + "\n") + f.flush() + + bar.update(1) + bar.close() + + return res + + def apply_chat_template( + self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True + ) -> str: + """Apply chat template - same as LLaDA2""" + chat_templated = self.tokenizer.apply_chat_template( + chat_history, tokenize=False, add_generation_prompt=add_generation_prompt + ) + if self.assistant_prefix: + chat_templated += self.assistant_prefix + return chat_templated diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/__init__.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89a01de1866ac6d741c08d46d1a3ad8857c902ee --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/__init__.py @@ -0,0 +1,19 @@ +from . import ( + LLaDA, + huggingface, +) +# from .configuration_llada import LLaDAConfig +# from .modeling_llada import LLaDAModelLM + + +try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + import huggingface_hub.constants # type: ignore + + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True +except ImportError: + pass + + +# __all__ = ['LLaDAConfig', 'LLaDAModelLM'] diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/dummy.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..4702a36cb29809c9dd08c516b99e74e71ffcc166 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/dummy.py @@ -0,0 +1,41 @@ +import random + +from tqdm import tqdm + +from dllm_eval.api.model import LM +from dllm_eval.api.registry import register_model + + +@register_model("dummy") +class DummyLM(LM): + def __init__(self) -> None: + super().__init__() + + @classmethod + def create_from_arg_string(cls, arg_string, additional_config=None): + return cls() + + def loglikelihood(self, requests, disable_tqdm: bool = False): + res = [] + + for _ in tqdm(requests, disable=disable_tqdm): + res.append((-random.random(), False)) + + return res + + def generate_until(self, requests, disable_tqdm: bool = False): + res = [] + + for request in tqdm(requests, disable=disable_tqdm): + res.append("lol") + assert request.arguments[0].strip() != "" + + return res + + def loglikelihood_rolling(self, requests, disable_tqdm: bool = False): + res = [] + + for _ in tqdm(requests, disable=disable_tqdm): + res.append(-random.random()) + + return res diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/hts_sampler.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/hts_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..7c4a0fdfb29138f5c80fb3596f44375075f164cc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/hts_sampler.py @@ -0,0 +1,315 @@ +import torch +import torch.nn.functional as F +import numpy as np +from .verifier import 
CodeVerifier +import logging +import re +import math + +logger = logging.getLogger(__name__) + +class HTSSampler: + def __init__(self, model, tokenizer, device="cuda"): + self.model = model + self.tokenizer = tokenizer + self.device = device + self.verifier = CodeVerifier(model, tokenizer, device) + + def _get_num_transfer_tokens(self, block_length, steps): + if steps == 0: return torch.tensor([], dtype=torch.int64) + base = block_length // steps + remainder = block_length % steps + num_transfer_tokens = torch.full((steps,), base, dtype=torch.int64) + num_transfer_tokens[:remainder] += 1 + return num_transfer_tokens + + def _sample_with_temperature(self, logits, temperature, top_k, top_p): + logits = logits.to(torch.float32) + + orig_probs = torch.softmax(logits, dim=-1) + x0_p, _ = torch.max(orig_probs, dim=-1) + + if temperature > 0.0: + noise = torch.rand_like(logits, dtype=torch.float32) + gumbel_noise = -torch.log(-torch.log(noise + 1e-10) + 1e-10) + logits = logits / temperature + gumbel_noise + + if top_k is not None and top_k > 0: + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = -float('Inf') + + x0 = torch.argmax(logits, dim=-1) + + return x0, x0_p + + def _safe_scalar(self, val): + if isinstance(val, torch.Tensor): + if val.numel() > 1: return val.mean().item() + return val.item() + return float(val) + + def _analyze_structure(self, text, task_type="code"): + score = 0.0 + stripped = text.strip() + if task_type == "code": + if len(stripped) < 5: return -0.1 + keywords = ["return", "print", "yield", "lambda", "class ", "def "] + if any(k in stripped for k in keywords): score += 0.05 + if ":" in stripped: score += 0.02 + if " " in text: score += 0.03 + elif task_type == "math": + if "\\boxed{" in stripped: score += 0.1 + if "The answer is" in stripped: score += 0.05 + if len(stripped) < 10: return -0.1 + if "Step" in stripped and stripped.count("Step") > 15: score -= 0.2 + return score + + def _chunked_forward(self, x, chunk_size=96, slice_indices=None): + total_batch = x.shape[0] + logits_list = [] + for i in range(0, total_batch, chunk_size): + end_idx = min(i + chunk_size, total_batch) + sub_x = x[i:end_idx] + sub_mask = torch.ones_like(sub_x, device=self.device) + with torch.no_grad(): + outputs = self.model(input_ids=sub_x, attention_mask=sub_mask) + sub_logits = outputs.logits + if slice_indices is not None: + s_start, s_end = slice_indices + sub_logits = sub_logits[:, s_start:s_end, :] + logits_list.append(sub_logits.detach().clone()) + return torch.cat(logits_list, dim=0) + + def _branch_and_resample(self, x, conf_scores, survivor_indices, target_width, mask_id, + prompt_length, resample_window=5, task_type="code"): + num_survivors = len(survivor_indices) + if num_survivors == 0: return x[:target_width].clone(), conf_scores[:target_width].clone() + + if task_type == "math": resample_window = 6 + elif task_type == "reasoning": resample_window = 6 + elif task_type == "code": resample_window = 6 + + base_repeat = target_width // num_survivors + remainder = target_width % num_survivors + new_x_list = [] + new_conf_list = [] + + for i in range(num_survivors): + count = base_repeat + (1 if i < remainder else 0) + if count == 0: continue + + survivor_x = x[survivor_indices[i]] + survivor_conf = conf_scores[survivor_indices[i]] + + new_x_list.append(survivor_x.unsqueeze(0)) + new_conf_list.append(survivor_conf.unsqueeze(0)) + + if count > 1: + gen_part = survivor_x[prompt_length:] + gen_conf = survivor_conf[prompt_length:] + 
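# Editor's sketch (not part of the patch): the per-step unmasking budget computed by
# _get_num_transfer_tokens above -- block_length positions spread evenly over `steps`,
# with the remainder assigned to the earliest steps.
import torch

def num_transfer_tokens(block_length, steps):
    base, remainder = divmod(block_length, steps)
    schedule = torch.full((steps,), base, dtype=torch.int64)
    schedule[:remainder] += 1
    return schedule

assert num_transfer_tokens(32, 5).tolist() == [7, 7, 6, 6, 6]   # sums to 32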
non_mask_indices = (gen_part != mask_id).nonzero(as_tuple=True)[0] + + for _ in range(count - 1): + perturbed_x = survivor_x.clone() + perturbed_conf = survivor_conf.clone() + + if len(non_mask_indices) > 0: + pool_size = min(resample_window * 2, len(non_mask_indices)) + current_token_confs = gen_conf[non_mask_indices] + + _, candidate_indices = torch.topk(current_token_confs, k=pool_size, largest=False) + + num_to_perturb = min(resample_window, pool_size) + rand_indices = torch.randperm(pool_size, device=self.device)[:num_to_perturb] + selected_sub_indices = candidate_indices[rand_indices] + + target_indices_in_x = prompt_length + non_mask_indices[selected_sub_indices] + perturbed_x[target_indices_in_x] = mask_id + perturbed_conf[target_indices_in_x] = 0.0 + + new_x_list.append(perturbed_x.unsqueeze(0)) + new_conf_list.append(perturbed_conf.unsqueeze(0)) + + return torch.cat(new_x_list, dim=0), torch.cat(new_conf_list, dim=0) + + @torch.no_grad() + def generate_hts(self, prompt_text, input_ids, problem_data=None, + initial_N=1, final_K=1, survivor_K=None, + prune_step_pct=0.0, reward_mode="confidence", + temperature=0.7, block_length=32, steps=64, gen_length=1024, + top_p=0.95, top_k=None, minimal_topk=1, threshold=0.9, + eos_id=156892, mask_id=156895, + hts_mode=False, hts_start_pct=0.1, hts_end_pct=0.6, decay_factor=1.5, + hts_survivor_k=4, task_type="code", until=None, pruning_interval=0): + + input_ids = input_ids.to(self.device) + if input_ids.shape[0] == 1: input_ids = input_ids.repeat(initial_N, 1) + + schedule_map = {} + ts_start, tr_end = 0, 0 + if not hts_mode: + final_K_list = [final_K] if not isinstance(final_K, list) else final_K + prune_pct_list = [prune_step_pct] if not isinstance(prune_step_pct, list) else prune_step_pct + survivor_K_list = final_K_list if survivor_K is None else ([survivor_K] if not isinstance(survivor_K, list) else survivor_K) + if len(survivor_K_list) < len(final_K_list): survivor_K_list.extend(final_K_list[len(survivor_K_list):]) + for pct, width, parents in zip(prune_pct_list, final_K_list, survivor_K_list): + if pct > 0: + s = int(steps * pct) + schedule_map[s] = (width, parents) + else: + final_K_list = [final_K] if not isinstance(final_K, int) else [final_K] + ts_start, tr_end = int(steps * hts_start_pct), int(steps * hts_end_pct) + + + prompt_length = input_ids.shape[1] + num_blocks = (gen_length + block_length - 1) // block_length + total_length = prompt_length + num_blocks * block_length + + x = torch.full((initial_N, total_length), mask_id, dtype=torch.long, device=self.device) + x[:, :prompt_length] = input_ids.clone() + + conf_scores = torch.zeros((initial_N, total_length), dtype=torch.float32, device=self.device) + conf_scores[:, :prompt_length] = 1.0 + + prefill_blocks = 0 + num_gen_blocks = num_blocks + current_bsz = initial_N + + next_allowed_pruning_step = ts_start if hts_mode else 0 + + stats = { + "initial_n": initial_N, "final_k": final_K_list[-1], + "pruning_history": [], "entropy_history": [], "nfe": 0.0, + "svf_calls": 0, "final_scores": [], "total_steps": steps, + "first_block_nfe": 0.0, "num_gen_blocks": [], "steps_per_block": [] + } + + for num_block in range(num_gen_blocks): + stats["num_gen_blocks"].append(num_block) + + window_start = prompt_length + num_block * block_length + window_end = window_start + block_length + + schedule = self._get_num_transfer_tokens(block_length, steps) + + steps_this_block = 0 + for step in range(steps): + steps_this_block += 1 + cur_full_x = x[:current_bsz, :] + + perform_pruning = False + 
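# Editor's sketch (not part of the patch): the exponential beam-narrowing schedule applied
# just below in the HTS pruning branch -- starting from initial_N candidates, the surviving
# width decays by `decay_factor` per step until it reaches the final beam size.
import math

initial_N, final_K, decay_factor, ts_start = 8, 1, 1.5, 3
widths = [max(final_K, math.ceil(initial_N * decay_factor ** -(step - ts_start)))
          for step in range(3, 10)]
assert widths == [8, 6, 4, 3, 2, 2, 1]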
num_parents_to_select = 0 + + if hts_mode and step >= next_allowed_pruning_step and step < tr_end: + target_width = max(final_K_list[-1], math.ceil(initial_N * (decay_factor ** -(step - ts_start)))) + if current_bsz > target_width: + perform_pruning = True + num_parents_to_select = hts_survivor_k + elif not hts_mode and step in schedule_map: + target_width, num_parents_to_select = schedule_map[step] + if current_bsz > target_width: perform_pruning = True + + if perform_pruning: + stats["nfe"] += current_bsz + if num_block == 0: stats["first_block_nfe"] += current_bsz + stats["svf_calls"] += current_bsz + + gen_logits = self._chunked_forward(cur_full_x, chunk_size=64, slice_indices=(prompt_length, total_length)) + rough_ids = torch.argmax(gen_logits, dim=-1) + rough_codes_snippet = self.tokenizer.batch_decode(rough_ids, skip_special_tokens=True) + candidates = [] + for i in range(current_bsz): + full_code = rough_codes_snippet[i] + s = self._safe_scalar(self.verifier.get_reward(prompt_text, full_code, mode=reward_mode, problem_data=problem_data, current_logits=gen_logits[i] if reward_mode != "svf" else None, task_type=task_type)) + s += self._analyze_structure(full_code, task_type=task_type) + clean_content = full_code.strip().replace(" ", "").replace("\n", "") + candidates.append({'score': s, 'idx': i, 'key': hash(clean_content[:200] + clean_content[-200:])}) + + stats["pruning_history"].append({"step": step, "scores": [c['score'] for c in candidates]}) + candidates.sort(key=lambda x: x['score'], reverse=True) + + selected_indices, seen_keys = [], set() + for cand in candidates: + if len(selected_indices) >= num_parents_to_select: break + if cand['key'] not in seen_keys: + selected_indices.append(cand['idx']); seen_keys.add(cand['key']) + + if len(selected_indices) < num_parents_to_select: + for cand in candidates: + if len(selected_indices) >= num_parents_to_select: break + if cand['idx'] not in selected_indices: selected_indices.append(cand['idx']) + + top_indices = torch.tensor(selected_indices, device=self.device) + x, conf_scores = self._branch_and_resample(x, conf_scores, top_indices, target_width, mask_id, prompt_length, task_type=task_type) + + current_bsz = target_width + cur_full_x = x[:current_bsz, :] + next_allowed_pruning_step = step + 1 + pruning_interval + + stats["nfe"] += current_bsz + if num_block == 0: stats["first_block_nfe"] += current_bsz + + active_logits = self._chunked_forward(cur_full_x, chunk_size=32, slice_indices=(window_start, window_end)) + active_logits[:, :, eos_id] = -1e10 + + x0, x0_p = self._sample_with_temperature(active_logits, temperature, top_k, top_p) + + active_mask = x[:current_bsz, window_start:window_end] == mask_id + + num_transfer = schedule[step].item() + confidence = torch.where(active_mask, x0_p, -torch.inf) + transfer_idx = torch.zeros_like(x0, dtype=torch.bool) + + for b in range(current_bsz): + mask_count = active_mask[b].sum().item() + if mask_count > 0: + k_transfer = min(num_transfer, mask_count) + active_indices = torch.where(active_mask[b])[0] + high_conf_mask = (confidence[b] > threshold) & active_mask[b] + if high_conf_mask.sum().item() >= k_transfer: + conf_indices = torch.where(high_conf_mask)[0] + transfer_idx[b, conf_indices] = True + else: + _, topk_indices = torch.topk(confidence[b][active_indices], k=min(k_transfer, len(active_indices))) + transfer_idx[b, active_indices[topk_indices]] = True + + if transfer_idx.any(): + x[:current_bsz, window_start:window_end][transfer_idx] = x0[transfer_idx] + conf_scores[:current_bsz, 
window_start:window_end][transfer_idx] = x0_p[transfer_idx] + + if task_type in ["math", "reasoning"]: + for b in range(current_bsz): + text_snippet = self.tokenizer.decode(x[b, prompt_length:window_end], skip_special_tokens=True) + should_stop = False + if task_type == "reasoning" and ("###" in text_snippet): should_stop = True + if task_type == "math" and ("\\boxed{" in text_snippet and "}" in text_snippet.split("\\boxed{")[-1]): should_stop = True + + if should_stop: + after_mask = (x[b, window_start:total_length] == mask_id) + x[b, window_start:total_length][after_mask] = eos_id + + stats["steps_per_block"].append(steps_this_block) + x = x[:current_bsz] + + stats["nfe"] = int(round(stats["nfe"])) + stats["first_block_nfe"] = int(round(stats["first_block_nfe"])) + + final_gen_tokens = x[:current_bsz, prompt_length:] + final_codes = self.tokenizer.batch_decode(final_gen_tokens, skip_special_tokens=True) + final_candidates = [] + + stats["svf_calls"] += len(final_codes) + for i in range(len(final_codes)): + txt = final_codes[i] + if until: + for term in until: + if term in txt: txt = txt.split(term)[0] + s = self._safe_scalar(self.verifier.get_reward(prompt_text, txt, mode=reward_mode, task_type=task_type)) + s += self._analyze_structure(txt, task_type) + final_candidates.append({'resp': txt, 'score': s}) + + final_candidates.sort(key=lambda x: x['score'], reverse=True) + stats["final_scores"] = [c['score'] for c in final_candidates] + stats["all_trajectories"] = [{"rank": i+1, "resp": c['resp'], "score": c['score']} for i, c in enumerate(final_candidates)] + + return [c['resp'] for c in final_candidates], stats \ No newline at end of file diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/huggingface.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..bf6e1e99e20aeed5b20f7cd2d7a8f9b76155330a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/huggingface.py @@ -0,0 +1,1489 @@ +import copy +import logging +import os +from datetime import timedelta +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import jinja2 +import torch +import torch.nn.functional as F +import transformers +from accelerate import ( + Accelerator, + InitProcessGroupKwargs, + find_executable_batch_size, +) +from accelerate.utils import get_max_memory +from huggingface_hub import HfApi +from packaging import version +from peft import PeftModel +from peft import __version__ as PEFT_VERSION +from tqdm import tqdm +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, +) + +from dllm_eval import utils +from dllm_eval.api.instance import Instance +from dllm_eval.api.model import TemplateLM +from dllm_eval.api.registry import register_model +from dllm_eval.models.utils import ( + Collator, + clear_torch_cache, + configure_pad_token, + get_dtype, + handle_stop_sequences, + pad_and_concat, + stop_sequences_criteria, +) + + +eval_logger = logging.getLogger(__name__) + + +@register_model("hf-auto", "hf", "huggingface") +class HFLM(TemplateLM): + """ + An abstracted Huggingface model class. Enables usage with both models of + `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes. + + Supports data-parallel multi-GPU with HF Accelerate. 
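# Editor's sketch (not part of the patch): the general shape of a register_model-style
# decorator as used for @register_model("hf-auto", "hf", "huggingface") above. The names
# below are illustrative assumptions, not the actual dllm_eval.api.registry code.
MODEL_REGISTRY = {}

def register_model(*aliases):
    def decorator(cls):
        for alias in aliases:
            MODEL_REGISTRY[alias] = cls
        return cls
    return decorator

@register_model("hf-auto", "hf", "huggingface")
class _ExampleLM:
    pass

assert MODEL_REGISTRY["hf"] is MODEL_REGISTRY["huggingface"]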
+ """ + + AUTO_MODEL_CLASS = None + _DEFAULT_MAX_LENGTH = 2048 + + def __init__( + self, + pretrained: Union[str, transformers.PreTrainedModel], + backend: Literal["default", "causal", "seq2seq"] = "default", + # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq) + revision: Optional[str] = "main", + subfolder: str = "", + tokenizer: Optional[ + Union[ + str, + transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast, + ] + ] = None, + truncation: Optional[bool] = False, + logits_cache: bool = True, + max_length: Optional[int] = None, + device: Optional[str] = "cuda", + dtype: Optional[Union[str, torch.dtype]] = "auto", + softmax_dtype: Optional[Union[str, torch.dtype]] = None, + batch_size: Optional[Union[int, str]] = 1, + max_batch_size: Optional[int] = 64, + trust_remote_code: Optional[bool] = False, + use_fast_tokenizer: Optional[bool] = True, + add_bos_token: Optional[bool] = False, + prefix_token_id: Optional[int] = None, + # arguments used for splitting a model across GPUs naively. + # only used if `parallelize=True`. + parallelize: Optional[bool] = False, + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = "./offload", + # PEFT, delta weights and quantization options + peft: Optional[str] = None, + delta: Optional[str] = None, + autogptq: Optional[Union[bool, str]] = False, + gptqmodel: Optional[bool] = False, + gguf_file: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__() + # optionally: take in an already-initialized transformers.PreTrainedModel + if not isinstance(pretrained, str): + eval_logger.warning( + "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way." + ) + assert not parallelize, ( + "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`" + ) + self._model = pretrained + self._device = self._model.device + self._config = self._model.config + gpus = 0 + + else: + assert isinstance(device, str) + assert isinstance(pretrained, str) + assert isinstance(batch_size, (int, str)) + + gpus = torch.cuda.device_count() + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + if accelerator.num_processes > 1: + self.accelerator = accelerator + + if "npu" in accelerator.device.type: + gpus = torch.npu.device_count() + + # using one process with no model parallelism + if not (parallelize or accelerator.num_processes > 1): + # use user-passed device + device_list = set( + ["cuda", "cpu"] + + [f"cuda:{i}" for i in range(gpus)] + + ["mps", "mps:0"] + + [f"npu:{i}" for i in range(gpus)] + ) + if device and device in device_list: + self._device = torch.device(device) + eval_logger.info(f"Using device '{device}'") + if device in ("mps", "mps:0") and version.parse( + torch.__version__ + ) < version.parse("2.1"): + raise RuntimeError( + f"mps requires torch >= 2.1. You have {torch.__version__}" + ) + else: + eval_logger.info("Device not specified") + eval_logger.info(f"Cuda Available? 
{torch.cuda.is_available()}") + self._device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) + else: # Parallelism managed by accelerate + if device != "cuda": + eval_logger.info( + f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model." + ) + # TODO: include in warning that `load_in_8bit` etc. affect this too + self._device = ( + self.accelerator.device + if hasattr(self, "accelerator") + else torch.device(device) + ) + + revision = str(revision) # cast to string if not already one + + self._get_config( + pretrained, + revision=revision, + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + subfolder=subfolder, + ) + + # determine which of 'causal' and 'seq2seq' backends to use for HF models + self._get_backend( + config=self.config, backend=backend, trust_remote_code=trust_remote_code + ) + + # load tokenizer so we know tokenizer vocabulary size before loading model and PEFT + self._create_tokenizer( + pretrained, + tokenizer, + revision=revision, + subfolder=subfolder, + trust_remote_code=trust_remote_code, + use_fast_tokenizer=use_fast_tokenizer, + gguf_file=gguf_file, + add_bos_token=add_bos_token, + ) + + # if we passed `pretrained` as a string, initialize our model now + if isinstance(pretrained, str): + self._create_model( + pretrained=pretrained, + revision=revision, + dtype=dtype, + trust_remote_code=trust_remote_code, + parallelize=parallelize, + gpus=gpus, + max_memory_per_gpu=max_memory_per_gpu, + max_cpu_memory=max_cpu_memory, + offload_folder=offload_folder, + peft=peft, + delta=delta, + autogptq=autogptq, + gptqmodel=gptqmodel, + gguf_file=gguf_file, + quantization_config=getattr(self.config, "quantization_config", None), + subfolder=subfolder, + **kwargs, + ) + + # access self._model through self.model property outside this method + if isinstance(self.model, torch.nn.Module): + self.model.eval() + self.model.tie_weights() + + self.truncation = truncation + self.logits_cache = logits_cache + self.vocab_size = self.tokenizer.vocab_size + # select (or create) a pad token to use + self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config) + + self.add_bos_token = add_bos_token + if "gemma" in getattr(self.config, "model_type", ""): + self.add_bos_token = True + eval_logger.info( + f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it." + ) + + self._max_length = max_length + self.pretrained = pretrained + self.delta = delta + self.peft = peft + self.revision = revision + self.batch_schedule = 1 + self.batch_sizes = {} + self.max_batch_size = max_batch_size + self.softmax_dtype = ( + get_dtype(softmax_dtype) if softmax_dtype is not None else None + ) + + if str(batch_size).startswith("auto"): + batch_size = batch_size.split(":") + self.batch_size_per_gpu = batch_size[0] + self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1 + else: + self.batch_size_per_gpu = int(batch_size) + + if isinstance(pretrained, str): + if gpus >= 1 or str(self.device) == "mps": + # TODO: can remove this whole snippet except in the mps case, perhaps? + if not (parallelize or autogptq or hasattr(self, "accelerator")): + # place model onto device requested manually, + # if not using HF Accelerate or device_map + # or any other option that preloads model onto device + try: + self.model.to(self.device) + except ValueError: + eval_logger.debug( + "Failed to place model onto specified device. 
This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore." + ) + # multigpu data-parallel support when launched with accelerate + if gpus > 1: + if accelerator.num_processes > 1: + if parallelize: + eval_logger.warning( + "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available." + ) + elif gpus > accelerator.num_processes: + eval_logger.warning( + "WARNING: The number of total system GPUs does not match the number of spawned processes. " + "If you would like to use data parallelism, please launch the script " + "with 'accelerate launch *script*'. " + f"Current run will proceed with {accelerator.num_processes} devices." + ) + if self.accelerator.is_local_main_process: + eval_logger.info( + f"Using {gpus} devices with data parallelism" + ) + + self._device = torch.device(f"{accelerator.device}") + self.accelerator = accelerator + + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + # if we aren't launching via accelerate, ditch + self._rank = 0 + self._world_size = 1 + else: + # if a PreTrainedModel was passed into HFLM, we forgo distributed setup. + eval_logger.warning( + "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration" + ) + self._rank = 0 + self._world_size = 1 + + self.custom_prefix_token_id = prefix_token_id + if prefix_token_id is not None: + eval_logger.info( + f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}" + ) + + def _get_accelerate_args( + self, + parallelize: Optional[bool] = None, + device_map: Optional[str] = "auto", + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[str] = "./offload", + gpus: Optional[int] = None, + ) -> dict: + """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`.""" + num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) + num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes + if ( + num_machines == 0 + and hasattr(self, "accelerator") + and self.accelerator is not None + ): + eval_logger.info( + "We are not in a distributed setting for accelerate. Setting model_parallel to False." 
+ ) + parallelize = False + + if parallelize is None: + # If parallelism is unset by the user, we automatically assign model parallelism + # if enough extra GPUs are available + max_memory_all_gpus = get_max_memory() + # We just want gpu, not cpu, max memory + if "cpu" in max_memory_all_gpus: + del max_memory_all_gpus["cpu"] + parallelize = bool(num_local_processes < len(max_memory_all_gpus)) + eval_logger.info( + f"Setting model parallel to {parallelize} since " + f"the number of local processes is {num_local_processes} " + f"and the number of GPUs is {len(max_memory_all_gpus)}" + ) + + args = {} + if parallelize: # Model parallelism will be used + max_memory = {} + if max_memory_per_gpu is not None: # Using the provided memory requirements + max_memory_per_gpu_map = { + device_idx: max_memory_per_gpu for device_idx in range(gpus) + } + else: # Estimating the possible memory requirements + max_memory_all_gpus = get_max_memory() + if "cpu" in max_memory_all_gpus: + del max_memory_all_gpus["cpu"] + if not hasattr(self, "accelerator"): + max_memory_per_gpu_map = { + k: v for k, v in max_memory_all_gpus.items() + } + else: + # use only 1 / num_processes of the GPUs if we are running under accelerate launch + max_memory_per_gpu_map = { + k: v + for k, v in max_memory_all_gpus.items() + if k % num_local_processes + == (self.accelerator.process_index % num_local_processes) + } + args["max_memory"] = max_memory_per_gpu_map + args["device_map"] = "auto" if device_map is None else device_map + eval_logger.info( + f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}" + ) + + if max_cpu_memory is not None: + max_memory["cpu"] = max_cpu_memory + + args["offload_folder"] = offload_folder + elif ( + device_map is None + ): # No model parallelism, we use the default provided device for our model + if hasattr(self, "accelerator"): + device_map = {"": f"{self.accelerator.device}"} + else: + device_map = {"": str(self.device)} + args["max_memory"] = None + args["device_map"] = device_map + eval_logger.info( + f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}" + ) + else: + args["max_memory"] = None + args["device_map"] = None + eval_logger.info("Model parallel was set to False.") + + return args + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. 
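# Editor's sketch (not part of the patch): how the per-process `max_memory` map built in
# _get_accelerate_args above ends up partitioned when launched via `accelerate launch`,
# e.g. 8 visible GPUs shared by 2 local processes. Values are stand-ins for
# accelerate.utils.get_max_memory().
all_gpus = {i: "40GiB" for i in range(8)}
num_local_processes, process_index = 2, 1
max_memory = {k: v for k, v in all_gpus.items()
              if k % num_local_processes == process_index % num_local_processes}
assert sorted(max_memory) == [1, 3, 5, 7]   # process 1 is assigned the odd-numbered GPUs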
+ return self._config + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def prefix_token_id(self): + # it is used as prefix for loglikelihood + if self.custom_prefix_token_id is not None: + return self.custom_prefix_token_id + if self.tokenizer.bos_token_id is not None: + return self.tokenizer.bos_token_id + return self.tokenizer.eos_token_id + + @property + def max_length(self): + if self._max_length: # if max length manually set, return it + return self._max_length + seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") + for attr in seqlen_config_attrs: + if hasattr(self.model.config, attr): + return getattr(self.model.config, attr) + if hasattr(self.tokenizer, "model_max_length"): + if self.tokenizer.model_max_length == 1000000000000000019884624838656: + return self._DEFAULT_MAX_LENGTH + return self.tokenizer.model_max_length + return self._DEFAULT_MAX_LENGTH + + @property + def max_gen_toks(self) -> int: + return 256 + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + @property + def tokenizer_name(self) -> str: + return self.tokenizer.name_or_path.replace("/", "__") + + def _get_backend( + self, + config: Union[transformers.PretrainedConfig, transformers.AutoConfig], + backend: Literal["default", "causal", "seq2seq"] = "default", + trust_remote_code: Optional[bool] = False, + ) -> None: + """ + Helper method during initialization. + Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used. + sets `self.AUTO_MODEL_CLASS` appropriately if not already set. + + **If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM, + user must set `self.backend` to be either "causal" or "seq2seq" manually!** + """ + + assert backend in ["default", "causal", "seq2seq"] + + if backend != "default": + # if we've settled on non-default backend, use that manually + if backend == "causal": + self.backend = backend + elif backend == "seq2seq": + self.backend = backend + eval_logger.info( + f"Overrode HF model backend type, and using type '{self.backend}'" + ) + else: + # determine and use the default HF backend for this model, based on its config + metadata. + if ( + getattr(config, "model_type") + in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES + ): + # first check if model type is listed under seq2seq models, since some + # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers. + # these special cases should be treated as seq2seq models. + self.backend = "seq2seq" + eval_logger.debug(f"Using model type '{self.backend}'") + elif ( + getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + ): + self.backend = "causal" + eval_logger.debug(f"Using model type '{self.backend}'") + else: + if not trust_remote_code: + eval_logger.warning( + "HF model type is neither marked as CausalLM or Seq2SeqLM. \ + This is expected if your model requires `trust_remote_code=True` but may be an error otherwise." 
+ "Setting backend to causal" + ) + # if model type is neither in HF transformers causal or seq2seq model registries + # then we default to assuming AutoModelForCausalLM + self.backend = "causal" + eval_logger.info( + f"Model type cannot be determined. Using default model type '{self.backend}'" + ) + + if self.AUTO_MODEL_CLASS is None: + if self.backend == "causal": + self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM + elif self.backend == "seq2seq": + self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM + + def _get_config( + self, + pretrained: str, + revision: str = "main", + trust_remote_code: bool = False, + gguf_file: Optional[str] = None, + subfolder: str = "", + ) -> None: + """Return the model config for HuggingFace models""" + self._config = transformers.AutoConfig.from_pretrained( + pretrained, + revision=revision, + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + subfolder=subfolder, + ) + + def _create_model( + self, + pretrained: str, + revision: Optional[str] = "main", + dtype: Optional[Union[str, torch.dtype]] = "auto", + trust_remote_code: Optional[bool] = False, + # arguments used for splitting a model across GPUs naively. + # only used if `parallelize=True`. + # (accelerate naive PP (device_map) options) + parallelize: Optional[bool] = False, + gpus: Optional[int] = None, + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[str] = "./offload", + # PEFT, delta weights and quantization options + peft: Optional[str] = None, + delta: Optional[str] = None, + autogptq: Optional[Union[bool, str]] = False, + gptqmodel: Optional[bool] = False, + gguf_file: Optional[str] = None, + quantization_config: Optional[Dict[str, Any]] = None, + subfolder: str = "", + **kwargs, + ) -> None: + """ + Initializes an HF or HF-compatible PreTrainedModel from scratch + inside HFLM, using the kwargs passed into self.__init__(). + + Also handles functionality such as AutoGPTQ usage and PEFT wrapping. + + For future similar extensions to AutoGPTQ that are not core to HF's ecosystem, + (such as PyTorch models that are nearly, but not quite, fully mirroring + HF's public interface relied on in this HFLM class) + please consider subclassing HFLM and overriding this and other methods as needed. + """ + + model_kwargs = kwargs if kwargs else {} + + model_kwargs.update( + self._get_accelerate_args( + parallelize=parallelize, + device_map=kwargs.get("device_map", None), + max_memory_per_gpu=max_memory_per_gpu, + max_cpu_memory=max_cpu_memory, + offload_folder=offload_folder, + gpus=gpus, + ) + ) + + if not autogptq and not gptqmodel: + if model_kwargs.get("load_in_4bit", None): + assert transformers.__version__ >= "4.30.0", ( + "load_in_4bit requires transformers >= 4.30.0" + ) + if transformers.__version__ >= "4.30.0": + if model_kwargs.get("load_in_4bit", None): + if model_kwargs.get("bnb_4bit_compute_dtype", None): + model_kwargs["bnb_4bit_compute_dtype"] = get_dtype( + model_kwargs["bnb_4bit_compute_dtype"] + ) + + self._model = self.AUTO_MODEL_CLASS.from_pretrained( + pretrained, + revision=revision, + torch_dtype=get_dtype(dtype), + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + quantization_config=quantization_config, + subfolder=subfolder, + **model_kwargs, + ) + else: + if autogptq and gptqmodel: + raise ValueError( + "Cannot use both 'autogptq' and 'gptqmodel' options at the same time." 
+ ) + + if autogptq: + try: + from auto_gptq import AutoGPTQForCausalLM + except ModuleNotFoundError as exception: + raise type(exception)( + "Tried to load auto_gptq, but auto-gptq is not installed ", + "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]", + ) + + self._model = AutoGPTQForCausalLM.from_quantized( + pretrained, + trust_remote_code=trust_remote_code, + model_basename=None if autogptq is True else Path(autogptq).stem, + use_safetensors=True + if autogptq is True + else autogptq.endswith(".safetensors"), + **model_kwargs, + ) + + if gptqmodel: + try: + from gptqmodel import GPTQModel + except ModuleNotFoundError as exception: + raise type(exception)( + "Tried to load gptqmodel, but gptqmodel is not installed ", + "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`", + ) + + self._model = GPTQModel.from_quantized( + pretrained, trust_remote_code=trust_remote_code, **model_kwargs + ) + + if peft and delta: + raise ValueError( + "Cannot use both 'peft' and 'delta' options at the same time." + ) + + if peft: + if model_kwargs.get("load_in_4bit", None): + if version.parse(PEFT_VERSION) < version.parse("0.4.0"): + raise AssertionError("load_in_4bit requires peft >= 0.4.0") + if self._model.config.vocab_size != len(self.tokenizer): + # resize model for LoRAs with added tokens + eval_logger.info( + f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..." + ) + self._model.resize_token_embeddings(len(self.tokenizer)) + self._model = PeftModel.from_pretrained( + self._model, peft, revision=revision + ) + elif delta: + if autogptq: + eval_logger.warning( + "Delta weights might trigger unexpected behavior when used with AutoGPTQ." + ) + _model_delta = self.AUTO_MODEL_CLASS.from_pretrained( + delta, + revision=revision, + torch_dtype=get_dtype(dtype), + trust_remote_code=trust_remote_code, + **model_kwargs, + ) + for name, param in self._model.state_dict().items(): + try: + param.data += _model_delta.state_dict()[name] + except KeyError: + raise KeyError(f"Delta model is missing weights for layer: {name}") + except Exception as e: + raise RuntimeError( + f"Failed to add delta weights to layer {name}. Error: {e}" + ) + + del _model_delta + + return None + + def _create_tokenizer( + self, + pretrained: Union[str, transformers.PreTrainedModel], + tokenizer: Optional[ + Union[ + str, + transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast, + ] + ], + revision: Optional[str] = "main", + trust_remote_code: Optional[bool] = False, + use_fast_tokenizer: Optional[bool] = True, + gguf_file: Optional[str] = None, + add_bos_token: Optional[bool] = False, + subfolder: Optional[str] = "", + ) -> None: + """ + Helper method during initialization. + + Create a tokenizer object corresponding to the correct + tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed. 
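# Editor's sketch (not part of the patch): the delta-weight merge performed above
# (param.data += delta_state_dict[name]), reproduced on a toy module.
import torch

base = torch.nn.Linear(4, 4)
delta = torch.nn.Linear(4, 4)
expected = base.weight.detach().clone() + delta.weight.detach()

delta_sd = delta.state_dict()
for name, param in base.state_dict().items():
    param.data += delta_sd[name]          # in place: base now holds base + delta

assert torch.allclose(base.weight, expected)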
+ """ + kwargs = { + "revision": revision, + "trust_remote_code": trust_remote_code, + } + + # gguf format embeds tokenizer and is not compatible with hf tokenizer `use_fast` param + if gguf_file is not None: + kwargs["gguf_file"] = gguf_file + else: + kwargs["use_fast"] = use_fast_tokenizer + + if add_bos_token: + kwargs["add_bos_token"] = True + + if subfolder: + kwargs["subfolder"] = subfolder + + if tokenizer: + if isinstance(tokenizer, str): + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + tokenizer, **kwargs + ) + else: + assert isinstance( + tokenizer, transformers.PreTrainedTokenizer + ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast) + self.tokenizer = tokenizer + else: + # Get tokenizer based on 'pretrained' + if isinstance(pretrained, str): + model_name = pretrained + else: + # get the HF hub name via accessor on model + model_name = self.model.name_or_path + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_name, **kwargs + ) + return None + + def _detect_batch_size(self, requests=None, pos: int = 0): + if requests: + _, context_enc, continuation_enc = requests[pos] + max_length = len( + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1] + ) + max_context_enc = len(context_enc[-(self.max_length + 1) :]) + max_cont_enc = len(continuation_enc[-(self.max_length + 1) :]) + else: + max_length = self.max_length + max_context_enc = max_length + max_cont_enc = max_length + + # if OOM, then halves batch_size and tries again + @find_executable_batch_size(starting_batch_size=self.max_batch_size) + def forward_batch(batch_size): + if self.backend == "seq2seq": + length = max(max_context_enc, max_cont_enc) + batched_conts = torch.ones( + (batch_size, length), device=self.device + ).long() + test_batch = torch.ones((batch_size, length), device=self.device).long() + call_kwargs = { + "attn_mask": test_batch, + "labels": batched_conts, + } + else: + call_kwargs = {} + test_batch = torch.ones( + (batch_size, max_length), device=self.device + ).long() + for _ in range(5): + out = F.log_softmax( # noqa: F841 + self._model_call(test_batch, **call_kwargs), + dim=-1, + dtype=self.softmax_dtype, + ) + + return batch_size + + try: + batch_size = forward_batch() + except RuntimeError as e: + if "No executable batch size found" in str(e): + batch_size = 1 + else: + raise + + if self.world_size > 1: + # if multi-GPU, always take minimum over all selected batch sizes + max_rnk_bs = torch.tensor([batch_size], device=self.device) + gathered = ( + self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist() + ) + batch_size = min(gathered) + clear_torch_cache() + return batch_size + + clear_torch_cache() + return batch_size + + def tok_encode( + self, string: str, left_truncate_len=None, add_special_tokens=None + ) -> List[int]: + """ """ + # default for None - empty dict, use predefined tokenizer param + # used for all models except for CausalLM or predefined value + special_tokens_kwargs = {} + + # by default for CausalLM - false or self.add_bos_token is set + if add_special_tokens is None: + if self.backend == "causal": + special_tokens_kwargs = { + "add_special_tokens": False or self.add_bos_token + } + # otherwise the method explicitly defines the value + else: + special_tokens_kwargs = {"add_special_tokens": add_special_tokens} + + encoding = self.tokenizer.encode(string, **special_tokens_kwargs) + + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = 
encoding[-left_truncate_len:] + + return encoding + + def tok_batch_encode( + self, + strings: List[str], + padding_side: str = "left", + left_truncate_len: int = None, + truncation: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. + old_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = padding_side + + add_special_tokens = {} + if self.backend == "causal": + add_special_tokens = {"add_special_tokens": False or self.add_bos_token} + + encoding = self.tokenizer( + strings, + truncation=truncation, + padding="longest", + return_tensors="pt", + **add_special_tokens, + ) + if left_truncate_len: + original_lengths = encoding["input_ids"].size(1) + if original_lengths > left_truncate_len: + eval_logger.warn( + f"Left truncation applied. Original sequence length was {original_lengths}, " + f"truncating to last {left_truncate_len} tokens. Some content will be lost.", + ) + encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] + encoding["attention_mask"] = encoding["attention_mask"][ + :, -left_truncate_len: + ] + self.tokenizer.padding_side = old_padding_side + + return encoding["input_ids"], encoding["attention_mask"] + + def tok_decode(self, tokens, skip_special_tokens=True): + return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def _model_call(self, inps, attn_mask=None, labels=None): + """ + :param inps: torch.Tensor + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape + [batch, sequence_ctx]. the size of sequence may vary from call to call + :param attn_mask: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :param labels: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. 
Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :return + A torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model's decoder + """ + with torch.no_grad(): + if attn_mask is not None or labels is not None: + assert attn_mask is not None and labels is not None + assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM + return self.model( + input_ids=inps, attention_mask=attn_mask, labels=labels + ).logits + else: + assert self.AUTO_MODEL_CLASS in ( + transformers.AutoModelForCausalLM, + transformers.AutoModelForVision2Seq, + ) + return self.model(inps).logits + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + # temperature = 0.0 if not set + # if do_sample is false and temp==0.0: + # remove temperature, as do_sample=False takes care of this + # and we don't want a warning from HF + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + # build stopping criteria + stopping_criteria = stop_sequences_criteria( + self.tokenizer, stop, context.shape[1], context.shape[0] + ) + return self.model.generate( + input_ids=context, + max_length=max_length, + stopping_criteria=stopping_criteria, + pad_token_id=self.tokenizer.pad_token_id, + use_cache=True, + **generation_kwargs, + ) + + def _select_cont_toks( + self, logits: torch.Tensor, contlen: int = None, inplen: int = None + ) -> torch.Tensor: + if self.backend == "causal": + assert contlen and inplen, ( + "Must pass input len and cont. len to select scored logits for causal LM" + ) + # discard right-padding. + # also discard the input/context tokens. we'll only score continuations. + logits = logits[inplen - contlen : inplen] + elif self.backend == "seq2seq": + assert contlen and not inplen, ( + "Selecting scored logits for Seq2SeqLM requires only cont. len" + ) + # only discard right-padding. + # the logits input to this fn only contain decoder-side tokens. + logits = logits[:contlen] + + return logits + + def loglikelihood_rolling( + self, requests: List[Instance], disable_tqdm: bool = False + ) -> List[float]: + adaptive_batch_size = None + if self.batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. 
Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + + # First, collect all windows from all requests + all_windows = [] # List of (request_idx, window) tuples + request_window_counts = [] # Track number of windows per request + + for req_idx, (string,) in enumerate( + tqdm( + [req.args for req in requests], + disable=(disable_tqdm or (self.rank != 0)), + ) + ): + rolling_token_windows: List[Tuple[List[int], List[int]]] = list( + map( + utils.make_disjoint_window, + utils.get_rolling_token_windows( + token_list=self.tok_encode(string), + prefix_token=self.prefix_token_id, + max_seq_len=self.max_length, + context_len=1, + ), + ) + ) + + # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case + windows = [(None,) + x for x in rolling_token_windows] + + # Store windows with their request index + all_windows.extend((req_idx, window) for window in windows) + request_window_counts.append(len(windows)) + + # Handle distributed case padding + pad_amnt = 0 + if self.world_size > 1: + mytensor = torch.tensor(len(all_windows), device=self.device) + gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist() + pad_amnt = max(gathered) - gathered[self.rank] + if pad_amnt > 0: + all_windows += pad_amnt * [all_windows[0]] + + all_nlls = [] + batch_size = adaptive_batch_size or self.batch_size + for i in range(0, len(all_windows), batch_size): + batch = all_windows[i : i + batch_size] + # Extract just the windows for processing, keeping track of request indices + batch_indices, batch_windows = zip(*batch) + + batch_nlls = self._loglikelihood_tokens( + requests=batch_windows, + disable_tqdm=False, + override_bs=len(batch_windows), + ) + # Store results with their request indices + all_nlls.extend(zip(batch_indices, batch_nlls)) + + # Remove padding if necessary + if (self.world_size > 1) and (pad_amnt > 0): + all_nlls = all_nlls[:-pad_amnt] + + # Reconstruct per-request loglikelihoods + loglikelihoods = [] + current_idx = 0 + for window_count in request_window_counts: + # Get all nlls for this request + request_nlls = all_nlls[current_idx : current_idx + window_count] + # Sum up the nlls for this request (discarding is_greedy) + request_total = sum(nll[0] for _, nll in request_nlls) + loglikelihoods.append(request_total) + current_idx += window_count + + string = requests[len(loglikelihoods) - 1].args[0] + self.cache_hook.add_partial( + "loglikelihood_rolling", (string,), request_total + ) + + return loglikelihoods + + def _batch_scheduler(self, pos, n_reordered_requests): + sched = pos // int(len(n_reordered_requests) / self.batch_schedule) + if sched in self.batch_sizes: + return self.batch_sizes[sched] + if (len(self.batch_sizes) > 1) and ( + self.batch_sizes[sched - 1] == self.max_batch_size + ): + # if previous batch size is already maximal, skip recomputation + self.batch_sizes[sched] = self.max_batch_size + return self.batch_sizes[sched] + print( + f"Passed argument batch_size = auto:{self.batch_schedule}. 
Detecting largest batch size" + ) + self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos) + print(f"Determined largest batch size: {self.batch_sizes[sched]}") + return self.batch_sizes[sched] + + def _loglikelihood_tokens( + self, + requests: List[Tuple[Tuple[str, str], List[int], List[int]]], + disable_tqdm: bool = False, + override_bs: int = None, + ) -> List[Tuple[float, bool]]: + # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context + res = [] + + def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]): + """Defines the key for the sorted method""" + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + + toks = req[1] + req[2] + return -len(toks), tuple(toks) + + def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): + """Defines the key to group and lookup one-token continuations""" + # Use with group_by="contexts" (optional)" + # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. + # speeds up some multiple-choice tasks proportionally to the number of choices. + # groups requests by context+continuation[:-1] and infer on one request/group. + return req[-2] + req[-1][:-1] + + re_ord = Collator( + requests, + sort_fn=_collate, + group_by="contexts" + if self.backend == "causal" and self.logits_cache + else None, + group_fn=_lookup_one_token_cont, + ) + + # automatic (variable) batch size detection for vectorization + # pull longest context sample from request + n_reordered_requests = len(re_ord) + batch_size = ( + self.batch_size + if self.batch_size != "auto" + else override_bs + if override_bs is not None + else 0 + ) + batch_fn = ( + self._batch_scheduler + if self.batch_size == "auto" + and n_reordered_requests > 0 + and not override_bs + else None + ) + + chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running loglikelihood requests", + ) + for chunk in chunks: + inps = [] + cont_toks_list = [] + inplens = [] + + conts = [] + encoder_attns = [] + + padding_len_inp = None + padding_len_cont = None + # because vectorizing is annoying, we first convert each (context, continuation) pair to padded + # tensors, then we pack them together into a batch, call the model, and then pick it all apart + # again because vectorizing is annoying + + for _, context_enc, continuation_enc in chunk: + # sanity check + assert len(context_enc) > 0 + assert len(continuation_enc) > 0 + assert len(continuation_enc) <= self.max_length + + # how this all works (illustrated on a causal decoder-only setup): + # CTX CONT + # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] + # model \ \ + # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the + # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice + + # when too long to fit in context, truncate from the left + if self.backend == "causal": + total_length = len(context_enc) + len(continuation_enc) + if 
total_length > self.max_length + 1: + eval_logger.warning( + f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) " + f"exceeds model's maximum length ({self.max_length}). " + f"Truncating {total_length - self.max_length + 1} tokens from the left." + ) + inp = torch.tensor( + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], + dtype=torch.long, + device=self.device, + ) + (inplen,) = inp.shape + elif self.backend == "seq2seq": + inp = torch.tensor( + (context_enc)[-self.max_length :], + dtype=torch.long, + device=self.device, + ) + (inplen,) = inp.shape + + # build encoder attn masks + encoder_attns.append(torch.ones_like(inp)) + + cont = torch.tensor( + (continuation_enc)[-self.max_length :], + # TODO: left-shift these? + # TODO: our code assumes we never end up truncating conts for either model type + dtype=torch.long, + device=self.device, + ) + (contlen,) = cont.shape + + conts.append(cont) + + padding_len_cont = ( + max(padding_len_cont, contlen) + if padding_len_cont is not None + else contlen + ) + + padding_len_inp = ( + max(padding_len_inp, inplen) + if padding_len_inp is not None + else inplen + ) + + inps.append(inp) # [1, inp_length] + cont_toks_list.append(continuation_enc) + inplens.append(inplen) + + # create encoder attn mask and batched conts, if seq2seq + call_kwargs = {} + if self.backend == "causal": + batched_inps = pad_and_concat( + padding_len_inp, inps, padding_side="right" + ) # [batch, padding_len_inp] + elif self.backend == "seq2seq": + # TODO: left-pad encoder inps and mask? + batched_inps = pad_and_concat( + padding_len_inp, inps + ) # [batch, padding_len_inp] + batched_conts = pad_and_concat( + padding_len_cont, conts + ) # [batch, padding_len_cont] + batched_encoder_mask = pad_and_concat( + padding_len_inp, encoder_attns + ) # [batch, padding_len_inp] + call_kwargs = { + "attn_mask": batched_encoder_mask, + "labels": batched_conts, + } + + multi_logits = F.log_softmax( + self._model_call(batched_inps, **call_kwargs), + dim=-1, + dtype=self.softmax_dtype, + ) # [batch, padding_length (inp or cont), vocab] + + for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip( + chunk, multi_logits, inplens, cont_toks_list + ): + # Slice to original seq length + contlen = len(cont_toks) + # take only logits in the continuation + # (discard context toks if decoder-only ; discard right-padding) + # also discards + checks for "virtual tokens" in the causal LM's input window + # from prompt/prefix tuning tokens, if applicable + ctx_len = ( + inplen + (logits.shape[0] - padding_len_inp) + if self.backend == "causal" + else None + ) + logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) + logits = logits.unsqueeze(0) # [1, seq, vocab] + + # Check if per-token argmax is exactly equal to continuation + greedy_tokens = logits.argmax(dim=-1) + + # check for one-token continuation cache hits. + # noop in case group_by != "contexts" or no cache hit and returns the + # original args. Otherwise, expands the logits batch dimension and yields each + # batch along with matching continuation tokens and prompt strings. + # logits -> [1, seq, vocab] + for request_str, cont_toks, logits in re_ord.get_cache( + req_str=request_str, + cxt_toks=ctx_tokens, + cont_toks=cont_toks, + logits=logits, + ): + cont_toks = torch.tensor( + cont_toks, dtype=torch.long, device=self.device + ).unsqueeze(0) # [1, seq] + # Use trailing slice [-cont_toks.shape[1]:] to handle variable length cont_len (but same ctx+cont[:-1]). 
+ # i.e. continuations can be sliced at diff points. Collator ensures we have sufficient greedy_tokens + # by choosing key with longest cont if group_by="contexts". + max_equal = ( + greedy_tokens[:, -cont_toks.shape[1] :] == cont_toks + ).all() + + # Obtain log-probs at the corresponding continuation token indices + # last_token_slice = logits[:, -1, :].squeeze(0).tolist() + logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze( + -1 + ) # [1, seq] + + # Answer: (log prob, is-exact-match) + answer = (float(logits.sum()), bool(max_equal)) + + res.append(answer) + + if request_str is not None: + # special case: loglikelihood_rolling produces a number of loglikelihood requests + # all with cache key None. instead do add_partial on the per-example level + # in the loglikelihood_rolling() function for those. + self.cache_hook.add_partial( + "loglikelihood", request_str, answer + ) + pbar.update(1) + + pbar.close() + + return re_ord.get_original(res) + + def generate_until( + self, requests: List[Instance], disable_tqdm: bool = False + ) -> List[str]: + res = [] + + def _collate(req: Tuple[str, dict]): + """Defines the key for the sorted method""" + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(req[0]) + return -len(toks), req[0] + + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running generate_until requests", + ) + adaptive_batch_size = None + if self.batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + # for each different set of kwargs, we execute all requests, by batch. + batch_size = ( + self.batch_size + if self.batch_size != "auto" + else adaptive_batch_size + if adaptive_batch_size is not None + else 0 + ) + batch_fn = ( + self._batch_scheduler + if self.batch_size == "auto" and not adaptive_batch_size + else None + ) + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + # group_fn=lambda x: x[1] -> x=(context, gen_kwargs) + re_ords = Collator( + [reg.args for reg in requests], + sort_fn=_collate, + group_by="gen_kwargs", + group_fn=lambda x: x[1], + ) + chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn) + eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False) + for chunk in chunks: + contexts, all_gen_kwargs = zip(*chunk) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + # unpack our keyword arguments. 
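+ # a typical `gen_kwargs` dict for one batch might look like (illustrative values):
+ #   {"until": ["\n\n"], "max_gen_toks": 256, "do_sample": False, "temperature": 0.0}
+ # `until` and `max_gen_toks` are consumed here; any remaining keys are forwarded
+ # unchanged to `self._model_generate()` below.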
+ if isinstance(gen_kwargs, dict): + kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 + # add EOS token to stop sequences + until = handle_stop_sequences(kwargs.pop("until", None), eos=eos) + else: + raise ValueError( + f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" + ) + if "max_gen_toks" in kwargs.keys(): + max_gen_toks = kwargs.pop("max_gen_toks") + else: + max_gen_toks = self.max_gen_toks + + # set the max length in tokens of inputs ("context_enc") + if self.backend == "causal": + # max len for inputs = max length, minus room to generate the max new tokens + max_ctx_len = self.max_length - max_gen_toks + assert max_ctx_len > 0, ( + f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})." + ) + elif self.backend == "seq2seq": + # max len for inputs = encoder's whole max_length + max_ctx_len = self.max_length + + # encode, pad, and truncate contexts for this batch + context_enc, attn_masks = self.tok_batch_encode( + contexts, + left_truncate_len=max_ctx_len, + truncation=self.truncation, + ) + context_enc = context_enc.to(self.device) + attn_masks = attn_masks.to(self.device) + + if "max_length" not in kwargs: + kwargs["max_length"] = context_enc.shape[1] + max_gen_toks + + # perform batched generation + cont = self._model_generate( + context=context_enc, + attention_mask=attn_masks, + stop=until, + **kwargs, + ) + + cont_toks_list = cont.tolist() + for cont_toks, context in zip(cont_toks_list, contexts): + # discard context + left-padding toks if using causal decoder-only LM + if self.backend == "causal": + cont_toks = cont_toks[context_enc.shape[1] :] + + s = self.tok_decode(cont_toks) + + # use secondary stop seqs to cut off should-have-been-stopped content post-hoc + for term in until: + if len(term) > 0: + # ignore '' separator, + # for seq2seq case where self.tok_decode(self.eot_token_id) = '' + s = s.split(term)[0] + + res.append(s) + + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + + return res + + def apply_chat_template( + self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True + ) -> str: + """ + Method to apply a chat template to a list of chat history between user and model. + """ + try: + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + except jinja2.exceptions.TemplateError: + eval_logger.warning( + "Failed to apply chat template. removing the system role in chat history." + ) + chat_history = [msg for msg in chat_history if msg["role"] != "system"] + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + + return chat_templated + + def get_model_info(self) -> dict: + """ + Method to get Hugging Face model information for experiment reproducibility. 
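+
+ Returns a dict with the keys `model_num_parameters`, `model_dtype`,
+ `model_revision` and `model_sha` (plus `peft_sha` / `delta_sha` when a PEFT
+ adapter or delta weights are in use).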
+ """ + + def get_model_num_params(model) -> int: + if hasattr(model, "num_parameters"): + return model.num_parameters() + if hasattr(model, "parameters"): + return sum(p.numel() for p in model.parameters()) + else: + return -1 + + def get_model_dtype(model) -> str: + if hasattr(model, "dtype"): + return model.dtype + else: + return "" + + def get_model_sha(pretrained: str, revision: str) -> str: + try: + model_info = HfApi().model_info(repo_id=pretrained, revision=revision) + return model_info.sha + except Exception as e: + eval_logger.debug( + f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}" + ) + return "" + + model_info = { + "model_num_parameters": get_model_num_params(self._model), + "model_dtype": get_model_dtype(self._model), + "model_revision": self.revision, + "model_sha": get_model_sha(self.pretrained, self.revision), + } + if self.peft: + model_info["peft_sha"] = get_model_sha(self.peft, self.revision) + if self.delta: + model_info["delta_sha"] = get_model_sha(self.delta, self.revision) + return model_info diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/utils.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e17fa224b22fbbef442c94e13d4f7c237d3c647d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/utils.py @@ -0,0 +1,854 @@ +import collections +import fnmatch +import gc +import itertools +import logging +import time +from functools import wraps +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Literal, + Optional, + Tuple, + Type, + Union, +) + +import torch +import transformers + + +eval_logger = logging.getLogger(__name__) + + +if TYPE_CHECKING: + from PIL import Image + from transformers import PreTrainedTokenizerBase + from transformers.configuration_utils import PretrainedConfig + + +def chunks(iter, n: int = 0, fn=None): + """ + Divides an iterable into chunks of specified size or based on a given function. + Useful for batching + + Parameters: + - iter: The input iterable to be divided into chunks. + - n: An integer representing the size of each chunk. Default is 0. + - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None. + + Returns: + An iterator that yields chunks of the input iterable. + + Example usage: + ``` + data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for chunk in chunks(data, 3): + print(chunk) + ``` + Output: + ``` + [1, 2, 3] + [4, 5, 6] + [7, 8, 9] + [10] + ``` + """ + arr = [] + for i, x in enumerate(iter): + arr.append(x) + if len(arr) == (fn(i, iter) if fn else n): + yield arr + arr = [] + + if arr: + yield arr + + +class MultiChoice: + def __init__(self, choices) -> None: + self.choices = choices + + # Simple wildcard support (linux filename patterns) + def __contains__(self, values) -> bool: + for value in values.split(","): + if len(fnmatch.filter(self.choices, value)) == 0: + eval_logger.info("Available tasks to choose:") + for choice in self.choices: + eval_logger.info(f" - {choice}") + raise ValueError("'{}' is not in task list".format(value)) + return True + + def __iter__(self) -> Iterator: + for choice in self.choices: + yield choice + + +class Grouper: + """ + takes an array `arr` and function `fn` and returns a dictionary + with keys fn(ob) for each ob in `arr` and with values `self.arr[key]` a list of all + objects in `arr` satisfying `key == fn(ob)`. 
+ """ + + def __init__(self, arr, fn) -> None: + # self.orig_arr = arr + self.size = len(arr) + arr = list(enumerate(arr)) + + def group_return_dict(arr, fn): + res = collections.defaultdict(list) + + for ob in arr: + res[fn(ob)].append(ob) + return res + + arr = group_return_dict(arr, lambda x: fn(x[1])) + + # self.arr has format Dict[Tuple[int, ]] + self.arr = arr + self._grouped = None + + def get_grouped(self): + # return the contents but not indices for our grouped dict. + if self._grouped: + return self._grouped + grouped = {} + for key in self.arr.keys(): + # drop the index from each element of self.arr + grouped[key] = [y[1] for y in self.arr[key]] + self._grouped = grouped + return grouped + + def get_original(self, grouped_dict): + # take in a grouped dictionary with e.g. results for each key listed + # in the same order as the instances in `self.arr`, and + # return the results in the same (single list) order as `self.orig_arr`. + res = [None] * self.size + cov = [False] * self.size + # orig = [None] * self.size + + assert grouped_dict.keys() == self.arr.keys() + + for key in grouped_dict.keys(): + for (ind, _), v in zip(self.arr[key], grouped_dict[key]): + res[ind] = v + cov[ind] = True + # orig[ind] = _ + + assert all(cov) + # assert orig == self.orig_arr + + return res + + +def pad_and_concat( + max_length: int, + tensors: List[torch.Tensor], + padding_side: Literal["right", "left"] = "right", +): + """ + Method for padding a list of tensors given the maximum tensor + length in the batch. Used for batching inputs and continuations in + seq2seq models. + """ + assert padding_side == "left" or padding_side == "right", ( + f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'" + ) + + for i, tensor in enumerate(tensors): + if len(tensor.shape) == 2: + tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size + tensor_len = tensor.shape[0] + if tensor_len < max_length: + if padding_side == "right": + # right-pad + tensors[i] = torch.cat( + [ + tensor, # [seq] + torch.zeros( + max_length - tensor_len, + dtype=torch.long, + device=tensor.device, + ), # [padding_length - seq] + ], + dim=0, + ).unsqueeze(0) + else: + # left-pad + tensors[i] = torch.cat( + [ + torch.zeros( + max_length - tensor_len, + dtype=torch.long, + device=tensor.device, + ), # [padding_length - seq] + tensor, # [seq] + ], + dim=0, + ).unsqueeze(0) + else: + tensors[i] = tensor.unsqueeze(0) + + return torch.cat(tensors, dim=0) + + +def clear_torch_cache() -> None: + gc.collect() + torch.cuda.empty_cache() + + +def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: + """Converts `dtype` from `str` to torch.dtype when possible. 
Does not use an instantiated HF AutoConfig""" + if isinstance(dtype, str) and dtype != "auto": + # Convert `str` args torch dtype: `float16` -> `torch.float16` + _torch_dtype = getattr(torch, dtype) + else: + _torch_dtype = dtype + return _torch_dtype + + +class MultiTokenEOSCriteria(transformers.StoppingCriteria): + """Criteria to stop on the specified multi-token sequence.""" + + def __init__( + self, + sequence: str, + tokenizer: transformers.PreTrainedTokenizer, + initial_decoder_input_length: int, + batch_size: int, + ) -> None: + self.initial_decoder_input_length = initial_decoder_input_length + self.done_tracker = [False] * batch_size + self.sequence = sequence + self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False) + # print(sequence, self.sequence_ids) + # we look back for 2 more tokens than it takes to encode our stop sequence + # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']` + # and we don't want to mistakenly not stop a generation because our + # (string) stop sequence was output in a different tokenization + + # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model, + # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized + # Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described. + self.sequence_id_len = len(self.sequence_ids) + 2 + self.tokenizer = tokenizer + + def __call__(self, input_ids, scores, **kwargs) -> bool: + # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence + lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :] + + lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :] + + lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch) + + for i, done in enumerate(self.done_tracker): + if not done: + self.done_tracker[i] = self.sequence in lookback_tokens_batch[i] + return False not in self.done_tracker + + +def stop_sequences_criteria( + tokenizer: transformers.PreTrainedTokenizer, + stop_sequences: List[str], + initial_decoder_input_length: int, + batch_size: int, +) -> transformers.StoppingCriteriaList: + return transformers.StoppingCriteriaList( + [ + *[ + MultiTokenEOSCriteria( + sequence, tokenizer, initial_decoder_input_length, batch_size + ) + for sequence in stop_sequences + ], + ] + ) + + +def undistribute(iterable): + """ + Undoes https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute . 
+ + Re-interleaves results that have been split using more_itertools.distribute: + >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 3, 5] + >>> list(group_2) + [2, 4, 6] + >>> undistribute([group_1, group_2]) + [1, 2, 3, 4, 5, 6] + + Handles non-uniform component lengths: + + >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 4, 7], [2, 5], [3, 6]] + >>> undistribute(children) + [1, 2, 3, 4, 5, 6, 7] + + Also handles when some iterables are empty: + + >>> children = distribute(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + >>> undistribute(children) + [1, 2, 3] + + """ + + return [ + x + for x in itertools.chain.from_iterable( + itertools.zip_longest(*[list(x) for x in iterable]) + ) + if x is not None + ] + + +def retry_on_specific_exceptions( + on_exceptions: List[Type[Exception]], + max_retries: Optional[int] = None, + backoff_time: float = 3.0, + backoff_multiplier: float = 1.5, + on_exception_callback: Optional[Callable[[Exception, float], Any]] = None, +): + """Retry on an LLM Provider's rate limit error with exponential backoff + For example, to use for OpenAI, do the following: + ``` + from openai import RateLimitError + + # Recommend specifying max_retries to avoid infinite loops! + @retry_on_specific_exceptions([RateLimitError], max_retries=3) + def completion(...): + # Wrap OpenAI completion function here + ... + ``` + """ + + def decorator(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + sleep_time = backoff_time + attempt = 0 + while max_retries is None or attempt < max_retries: + try: + return func(*args, **kwargs) + except tuple(on_exceptions) as e: + if on_exception_callback is not None: + on_exception_callback(e, sleep_time) + time.sleep(sleep_time) + sleep_time *= backoff_multiplier + attempt += 1 + + return wrapper + + return decorator + + +class Collator: + """ + A class for reordering and batching elements of an array. + + This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data. + + Objects of this class have the group_by attribute which determines the method for grouping + the data while batching it. Three options include "gen_kwargs", "contexts", or None: + If group_by == "gen_kwargs" then requests will be grouped by gen_kwargs + If group_by == "contexts" then requests will be grouped by context + cont[:-1] + If None then requests will just be reordered by length descending. + """ + + def __init__( + self, + arr: List, + sort_fn: Callable = lambda x: x, + group_fn: Callable = lambda x: x[1], + group_by: Union[Literal["gen_kwargs", "contexts"], None] = None, + ) -> None: + self._group_by = group_by + # 0 indices are enumerated indices. Apply functions to original arr. 
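+ # e.g. after `enumerate`, every element is `(orig_index, request)`; the wrappers
+ # below strip that index so the user-supplied sort_fn/group_fn only see the request.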
+ self._sort_fn = lambda x: sort_fn(x[1]) + self._group_fn = lambda x: group_fn(x[1]) + self._reorder_indices: List = [] + self._size = len(arr) + self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple( + enumerate(arr) + ) # [indices, (arr)] + if self._group_by == "contexts": + self._group_by_context() + elif self._group_by == "gen_kwargs": + self._group_by_index() + + def _group_by_index(self) -> None: + """Group the elements of a list based on their indices.""" + self._arr_with_indices = self.group( + self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs" + ) + + def _group_by_context(self) -> None: + """Group the array with indices by context.""" + self._arr_with_indices = self.group( + self._arr_with_indices, fn=self._group_fn, group_by="contexts" + ) + + def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator: + """ + Generates and yields batches from the reordered array. The method of grouping and batching + depends on the parameter `group_by`. + If `group_by` is set to "gen_kwargs", it will batch the + re-ordered values with same gen_kwargs for each batch. + If `group_by` is "contexts", it caches the requests by context before batching. + If `group_by` is neither "gen_kwargs" nor "contexts", it yields the reordered array + + Parameters: + - n (int): The size of each batch. Defaults to 1. + - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of + each batch. Optional, defaults to None. + + Returns: + Iterator: An iterator over batches of reordered elements grouped as per the `group_by` + attribute. + + Yields: + List of batched elements according to the `group_by` attribute. + """ + if self._group_by == "gen_kwargs": + for ( + key, + values, + ) in self._arr_with_indices.items(): # type: ignore + values = self._reorder(values) + batch = self.get_chunks(values, n=n, fn=batch_fn) + yield from batch + elif self._group_by == "contexts": + # Get one sample from each key. + # Select longest continuation per group to ensure sufficient context logits + values = self._reorder( + [ + max(value, key=lambda x: len(x[1][-1])) + for value in self._arr_with_indices.values() + ] + ) + batch = self.get_chunks(values, n=n, fn=batch_fn) + yield from batch + else: + values = self._reorder(self._arr_with_indices) # type: ignore + batch = self.get_chunks(values, n=n, fn=batch_fn) + yield from batch + + def get_cache( + self, + req_str: Tuple[str, str] = None, + cxt_toks: List[int] = None, + cont_toks: List[int] = None, + logits: torch.Tensor = None, + ) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]: + """ + Retrieves cached single-token continuations and their associated arguments, updating indices as necessary. + + The behavior of this function varies depending on how the `group_by` attribute is set: + + - When `group_by` is "contexts": + The function identifies single-token continuations by checking for keys that equate to + [context+continuation][-1] and logs the indices for re-ordering. + In this mode, this function can work in two scenarios: + + 1. Cache Hit - Single Match: + If a single matching context-continuation pair is found in the cache, + the function yields the original arguments. + + 2. Cache Hit - Multiple Matches: + If multiple matching context-continuation pairs are found in the cache, + the function expands the logits batch dimension to match the number of cache hits. + It updates the original requests and continuation tokens. 
+ + - When `group_by` is not set to "contexts": + This method yields the original arguments, logits and continuation tokens, + without checking for one-token continuations. + + Parameters: + - req_str (tuple[str, str]): Original strings used for CachingLM. + - cxt_toks (list[int]): Full context tokens used for lookup. + - cont_toks (list[int]): Continuation tokens for which logits were generated. + - logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by the model given context and continuation keys. + + Yields: + - Iterator: + - req_str (tuple[str, str]): strings used for CachingLM. + - cont_toks (list[int]) : continuation tokens. + - logits (torch.Tensor [1, seq_length, vocab_size]): The original logits (repeated cache hit times) + """ + if self._group_by == "contexts": + cache_hit: List[ + Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]] + ] = self._arr_with_indices.pop(tuple(cxt_toks + cont_toks[:-1])) + if (cache_size := len(cache_hit)) == 1: + self._reorder_indices.extend(x[0] for x in cache_hit) + yield req_str, cont_toks, logits + else: + # If we have matching requests then expand the batch dimension (no-op) and + # yield each along with its corresponding args. + multilogits = logits.expand(cache_size, -1, -1).chunk(cache_size) + indices, req_str, cont_toks = zip( + *[(x[0], x[1][0], x[-1][-1]) for x in cache_hit] + ) + self._reorder_indices.extend(indices) + for c_key, cont_tok, logit in zip(req_str, cont_toks, multilogits): + yield c_key, cont_tok, logit + else: + yield req_str, cont_toks, logits + + def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator: + """ + Reorders the elements in the array based on the sorting function. + + Parameters: + - arr (list | tuple[tuple[int, Any], ...]]): The array or iterable to be reordered. + + Yields: + Iterator + """ + arr = sorted(arr, key=self._sort_fn) + if not self._group_by == "contexts": + # If grouped by contexts then indices will be set in get_cache() + self._reorder_indices.extend([x[0] for x in arr]) + yield from [x[1] for x in arr] + + def get_original(self, newarr: List) -> List: + """ + Restores the original order of elements from the reordered list. + + Parameters: + - newarr (list): The reordered array. + + Returns: + list: The array with elements restored to their original order. + """ + res = [None] * self._size + cov = [False] * self._size + + for ind, v in zip(self._reorder_indices, newarr): + res[ind] = v + cov[ind] = True + + assert all(cov) + + return res + + def __len__(self): + return self._size + + @staticmethod + def group( + arr: Iterable, + fn: Callable, + group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs", + ) -> dict: + """ + Groups elements of an iterable based on a provided function. + + + The `group_by` parameter determines the method of grouping. + If `group_by` is "contexts", the elements are grouped by [context + cont][:-1]. + If `group_by` is "gen_kwargs", the elements are grouped based on the gen_kwargs dict. + + Parameters: + - arr (Iterable): The iterable to be grouped. + - fn (Callable): The function to determine the grouping. + - values (bool): If True, returns the values of the group. Defaults to False. + + Returns: + Iterator: An iterable of grouped elements. 
+ """ + res = collections.defaultdict(list) + for ob in arr: + # where ob == [context + cont] + if group_by == "contexts": + res[tuple(fn(ob))].append(ob) + else: + try: + hashable_dict = tuple( + ( + key, + tuple(value) + if isinstance(value, collections.abc.Iterable) + else value, + ) + for key, value in sorted(fn(ob).items()) + ) + res[hashable_dict].append(ob) + except (TypeError, AttributeError): + res[tuple(fn(ob))].append(ob) + return res + + @staticmethod + def get_chunks(_iter, n: int = 0, fn=None): + """ + Divides an iterable into chunks of specified size or based on a given function. + Useful for batching + + Parameters: + - iter: The input iterable to be divided into chunks. + - n: An integer representing the size of each chunk. Default is 0. + - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None. + + Returns: + An iterator that yields chunks of the input iterable. + + Example usage: + ``` + data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for chunk in chunks(data, 3): + print(chunk) + ``` + Output: + ``` + [1, 2, 3] + [4, 5, 6] + [7, 8, 9] + [10] + ``` + """ + arr = [] + _iter = tuple(_iter) + for i, x in enumerate(_iter): + arr.append(x) + if len(arr) == (fn(i, _iter) if fn else n): + yield arr + arr = [] + + if arr: + yield arr + + +def configure_pad_token( + tokenizer: "PreTrainedTokenizerBase", + model_config: Optional["PretrainedConfig"] = None, +) -> "PreTrainedTokenizerBase": + """ + This function checks if the (Hugging Face) tokenizer has a padding token and sets it if not present. + Some tokenizers require special handling. + + Args: + tokenizer: The tokenizer for which the padding token is to be handled. + model_config: The configuration of the model. Default is None. + + Returns: + The tokenizer after the padding token has been handled. + + Raises: + AssertionError: If the tokenizer is of type RWKVWorldTokenizer or Rwkv5Tokenizer and the padding token id is not 0. + """ + if tokenizer.pad_token: + pass + elif tokenizer.unk_token: + tokenizer.pad_token_id = tokenizer.unk_token_id + elif tokenizer.eos_token: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + # handle special cases + if model_config and getattr(model_config, "model_type", None) == "qwen": + # Qwen's trust_remote_code tokenizer does not allow for adding special tokens + tokenizer.pad_token = "<|endoftext|>" + elif ( + tokenizer.__class__.__name__ == "RWKVWorldTokenizer" + or tokenizer.__class__.__name__ == "Rwkv5Tokenizer" + ): + # The RWKV world tokenizer, does not allow for adding special tokens / setting the pad token (which is set as 0) + # The additional tokenizer name check is needed, as there exists rwkv4 models with neox tokenizer + # --- + # Note that the world tokenizer class name, might change in the future for the final huggingface merge + # https://github.com/huggingface/transformers/pull/26963 + assert tokenizer.pad_token_id == 0 + else: + tokenizer.add_special_tokens({"pad_token": "<|pad|>"}) + + return tokenizer + + +def replace_placeholders( + string: str, default_placeholder: str, image_token: str, max_images: int +): + """ + A utility function used for local multimodal models. It locates all `placeholder` string + occurrences in the given input `string_` and replaces the first `max_count` instances with + `replacement`, and all subsequent occurrences with the empty string. 
+ + This is used to replace placeholder tags by model-specific image tokens like <|image_pad|> + and to allow for only the first `max_count` images to be passed to a model if desired. + + :param string: The original string containing placeholders. + :param default_placeholder: The placeholder text to be replaced. + :param image_token: The token to replace the placeholder with. + :param max_images: The maximum number of replacements to make. + :return: The string with placeholders replaced. + """ + count = 0 + result = [] + + parts = string.split(default_placeholder) + for part in parts[:-1]: # Iterate through all but the last part + result.append(part) + if count < max_images: + result.append(image_token) + count += 1 + elif default_placeholder != image_token: + result.append(default_placeholder) + + # Add the last part of the string + result.append(parts[-1]) + return "".join(result) + + +def flatten_image_list(images: List[List]): + """ + Takes in a list of lists of images, and returns a single list of all images in order. + Used for some multimodal models like Llava-1.5 which expects this flattened-list format for its image processor. + + :param images: A list of lists of PIL images. + :return: a list of PIL images, via concatenating all the sub-lists in order. + """ + return [image for image_list in images for image in image_list] + + +def handle_stop_sequences( + until: Union[str, List[str], None], eos: Optional[str] +) -> List[str]: + """Ensures that the `until` parameter is a list of stop sequences and includes the EOS token.""" + if isinstance(until, str): + until = [until] + elif until is None: + until = [] + elif not isinstance(until, list): + raise ValueError( + f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}" + ) + + if eos is not None and eos not in until: + until.append(eos) + return until + + +def resize_image( + image: "Image.Image", + width: Optional[int] = None, + height: Optional[int] = None, + max_dimension: Optional[int] = None, + keep_aspect_ratio: bool = True, + resample_filter: Union[int, str] = "Image.BICUBIC", + min_width: int = 1, + min_height: int = 1, +) -> "Image.Image": + """ + Resizes a PIL Image object with flexible options. + + Args: + image: The PIL Image object to resize. + width: Target width in pixels. + height: Target height in pixels. + max_dimension: Maximum size for the longer dimension of the image. + keep_aspect_ratio: If True (default) and both width and height are provided, + the image is resized to fit within these dimensions while + maintaining its aspect ratio. If False, the image is stretched + to the exact width and height. + resample_filter: The resampling filter to use for resizing. + Defaults to Image.BICUBIC. + min_width: Minimum width for the resized image. Defaults to 1. + min_height: Minimum height for the resized image. Defaults to 1. + + Returns: + The resized PIL Image object. If no resize parameters are provided + or if the image already meets the criteria, the original image is returned. + + Order of precedence for resizing: + 1. If width AND height are provided: + - If keep_aspect_ratio is True: Fits image within bounds, preserving aspect ratio. + - If keep_aspect_ratio is False: Resizes to exact dimensions (may distort). + 2. Else if only width is provided: Calculates height proportionally. + 3. Else if only height is provided: Calculates width proportionally. + 4. Else if max_dimension is provided: Resizes the longest side to max_dimension + and scales the other side proportionally. + 5. 
If none of the above are provided, returns the original image. + """ + original_width, original_height = image.size + + # If no arguments are provided, return the original image + if width is None and height is None and max_dimension is None: + return image + + new_width = original_width + new_height = original_height + + if width is not None and height is not None: + # No resize needed if image is already smaller than target dimensions + if original_width <= width and original_height <= height: + return image + + if keep_aspect_ratio: + # Calculate the ratio to fit within the target dimensions + ratio = min(width / original_width, height / original_height) + new_width = int(original_width * ratio) + new_height = int(original_height * ratio) + else: + # Stretch to exact dimensions + new_width = width + new_height = height + elif width is not None: + # No resize needed if width is already smaller + if original_width <= width: + return image + # Calculate height proportionally + new_width = width + new_height = int((original_height / original_width) * new_width) + elif height is not None: + # No resize needed if height is already smaller + if original_height <= height: + return image + # Calculate width proportionally + new_height = height + new_width = int((original_width / original_height) * new_height) + elif max_dimension is not None: + # No resize needed if both dimensions are smaller than max_dimension + if max(original_height, original_width) <= max_dimension: + return image + + if original_width > original_height: + # Width is the longer side + new_width = max_dimension + new_height = int((original_height / original_width) * new_width) + else: + # Height is the longer side or sides are equal + new_height = max_dimension + new_width = int((original_width / original_height) * new_height) + + # Ensure dimensions are at least minimum values + new_width = max(min_width, new_width) + new_height = max(min_height, new_height) + + # Perform the resize operation with the calculated dimensions + return image.resize((new_width, new_height), resample_filter) + + +def truncate_tokens( + tokens: List[int], + max_length: int, + tokenizer: "PreTrainedTokenizerBase", + strategy: str = "left", +): + if strategy == "left": + return tokens[-max_length:] + elif strategy == "right": + return tokens[:max_length] + elif strategy == "middle": + # Truncate the middle of the sequence + left_length = max_length // 2 + right_length = max_length - left_length + return tokens[:left_length] + tokens[-right_length:] + return None diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/verifier.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/verifier.py new file mode 100644 index 0000000000000000000000000000000000000000..1a89831e5ee670740bc3833b09849115971545a2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/models/verifier.py @@ -0,0 +1,154 @@ +import torch +import logging +import ast +import re +import numpy as np +import textwrap + +logger = logging.getLogger(__name__) + +class CodeVerifier: + def __init__(self, model, tokenizer, device="cuda"): + self.model = model + self.tokenizer = tokenizer + self.device = device + + self.yes_ids, self.no_ids = [], [] + for t in ["Yes", " Yes", "YES"]: + ids = self.tokenizer.encode(t, add_special_tokens=False) + if len(ids) > 0: self.yes_ids.append(ids[-1]) + for t in ["No", " No", "NO"]: + ids = self.tokenizer.encode(t, add_special_tokens=False) + if len(ids) > 0: self.no_ids.append(ids[-1]) + + self.yes_ids = list(set(self.yes_ids)) + self.no_ids = 
list(set(self.no_ids)) + + def _extract_python_code(self, text): + text = text.strip() + match = re.search(r"```python\s*(.*?)```", text, re.DOTALL) + if match: return match.group(1) + match_generic = re.search(r"```\s*(.*?)```", text, re.DOTALL) + if match_generic: return match_generic.group(1) + return text + + def check_syntax(self, code_str): + clean_code = self._extract_python_code(code_str) + try: + if len(clean_code.strip()) < 5: return False + ast.parse(clean_code) + return True + except: + return False + + def compute_confidence(self, logits): + if logits is None: return 0.0 + probs = torch.softmax(logits, dim=-1) + max_probs, _ = torch.max(probs, dim=-1) + log_probs = torch.log(max_probs + 1e-10) + return torch.exp(torch.mean(log_probs)).item() + + def svf_score(self, prompt, code_str, task_type="code"): + + max_len = 2000 + if len(code_str) > max_len: + if task_type == "reasoning": + truncated_code = code_str[:500] + "\n...[truncated]...\n" + code_str[-(max_len-500):] + else: + truncated_code = code_str[-max_len:] + else: + truncated_code = code_str + + if task_type == "code": + prompt_template = f""" + You are an expert programming contest judge. Your task is to evaluate a generated solution for a given problem based on correctness, efficiency, and adherence to constraints. + + [Problem Statement] + {prompt} + [/Problem Statement] + + [Proposed Python Solution] + ```python + {truncated_code} + ``` + [/Proposed Python Solution] + + **Analysis Steps:** + 1. Correctness: Does the core algorithm correctly solve the problem? + 2. Efficiency: Is the time complexity acceptable for the given constraints? + 3. Edge Cases & Constraints: Does the code handle all rules and edge cases? + + **Conclusion**: Based on your analysis, is the solution likely to be fully correct? Answer with a single word: Yes or No. + **Answer:** """ + + elif task_type == "math": + prompt_template = f""" + You are an expert mathematician and competition judge. Your task is to evaluate a proposed mathematical solution for a given problem based on its logical rigor and accuracy. + + [Math Problem] + {prompt} + [/Math Problem] + + [Proposed Mathematical Solution] + {truncated_code} + [/Proposed Mathematical Solution] + + **Analysis Steps:** + 1. Reasoning Validity: Are the logical steps and mathematical properties applied correctly? + 2. Calculation Accuracy: Are the intermediate calculations or algebraic manipulations accurate? + 3. Goal Alignment: Does the current reasoning path directly lead toward the final answer required by the problem? + + **Conclusion**: Based on your analysis, is this solution path sound and likely to result in the correct final answer? Answer with a single word: Yes or No. + **Answer:** """ + + elif task_type == "reasoning": + prompt_template = f""" + You are an expert reading comprehension and faithfulness judge. Your task is to evaluate a generated answer based on the provided context and question. + + [Context and Question] + {prompt} + [/Context and Question] + + [Proposed Answer] + {truncated_code} + [/Proposed Answer] + + **Analysis Steps :** + 1. Faithfulness: Is the answer an exact, literal span from the context? + 2. Relevance: Does the answer directly address the specific question asked without hallucinating external information? + 3. Accuracy: Does the provided context strictly support this answer? + + **Conclusion**: Based on your analysis, is the answer fully faithful to the context and correct? Answer with a single word: Yes or No. 
+ **Answer:** """ + + else: + prompt_template = f"Is the following answer correct?\nQuestion: {prompt}\nAnswer: {truncated_code}\nAnswer Yes or No.\nAnswer:" + + verify_text = textwrap.dedent(prompt_template).strip() + input_ids = self.tokenizer(verify_text, return_tensors="pt").input_ids.to(self.device) + + max_pos = getattr(self.model.config, "max_position_embeddings", + getattr(self.model.config, "n_positions", + getattr(self.model.config, "max_sequence_length", 20480))) + + if input_ids.shape[1] > max_pos - 16: + logger.warning("Verifier input is too long, truncating from the left.") + input_ids = input_ids[:, -(max_pos - 16):] + + with torch.no_grad(): + outputs = self.model(input_ids) + logits = outputs.logits[0, -1, :] + + yes_score = max((logits[i].item() for i in self.yes_ids if i < logits.shape[-1]), default=-float('inf')) + no_score = max((logits[i].item() for i in self.no_ids if i < logits.shape[-1]), default=-float('inf')) + + if yes_score == -float('inf') and no_score == -float('inf'): return 0.5 + + probs = torch.softmax(torch.tensor([yes_score, no_score]), dim=0) + return probs[0].item() + + def get_reward(self, prompt, code_str, mode="confidence", problem_data=None, current_logits=None, task_type="code"): + if mode == "svf": + return self.svf_score(prompt, code_str, task_type=task_type) + else: + return self.compute_confidence(current_logits) \ No newline at end of file diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/prompts/__init__.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/prompts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c2ce897dcde522ac82d0cbe0e06db1e02b1b72 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/prompts/__init__.py @@ -0,0 +1,128 @@ +import ast +import logging +import os +from typing import Dict + +from dllm_eval import utils + + +eval_logger = logging.getLogger(__name__) + +# Prompt library. +# Stores prompts in a dictionary indexed by 2 levels: +# prompt category name, and prompt name. 
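+# e.g. the prompt id "qa-basic:question-newline-answer" resolves to
+# PROMPT_REGISTRY["qa-basic"]["question-newline-answer"] (see `get_prompt` below).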
+# This allows us to access prompts +PROMPT_REGISTRY: Dict[str, Dict[str, str]] = { + "qa-basic": { + "question-newline-answer": "Question: {{question}}\nAnswer:", + "q-newline-a": "Q: {{question}}\nA:", + }, +} + + +def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None): + # unpack prompt name + category_name, prompt_name = prompt_id.split(":") + if subset_name is None: + dataset_full_name = dataset_name + else: + dataset_full_name = f"{dataset_name}-{subset_name}" + eval_logger.info(f"Loading prompt from {category_name} for {dataset_full_name}") + if category_name == "promptsource": + try: + from promptsource.templates import DatasetTemplates + except ModuleNotFoundError as exception: + raise type(exception)( + "Tried to load a Promptsource template, but promptsource is not installed ", + "please install promptsource via pip install lm-eval[promptsource] or pip install -e .[promptsource]", + ) + try: + if subset_name is None: + prompts = DatasetTemplates(dataset_name=dataset_name) + else: + prompts = DatasetTemplates( + dataset_name=dataset_name, subset_name=subset_name + ) + except Exception: + raise ValueError(f"{dataset_name} and {subset_name} not found") + if prompt_name in prompts.all_template_names: + return prompts[prompt_name] + else: + raise ValueError( + f"{prompt_name} not in prompt list {prompts.all_template_names}" + ) + elif ".yaml" in category_name: + import yaml + + with open(category_name, "rb") as file: + prompt_yaml_file = yaml.full_load(file) + + prompt_string = prompt_yaml_file["prompts"][prompt_name] + return PromptString(prompt_string) + else: + try: + return PROMPT_REGISTRY[category_name][prompt_name] + except Exception: + raise ValueError( + f"expected only a single `:` as separator between \ + prompt category and name, but got `{prompt_id}` instead" + ) + + +def load_prompt_list( + use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs +): + category_name, prompt_name = use_prompt.split(":") + + if category_name == "promptsource": + from promptsource.templates import DatasetTemplates + + if subset_name is None: + prompts = DatasetTemplates(dataset_name=dataset_name) + else: + prompts = DatasetTemplates( + dataset_name=dataset_name, subset_name=subset_name + ) + + prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names) + + elif ".yaml" in category_name: + import yaml + + if yaml_path is not None: + category_name = os.path.realpath(os.path.join(yaml_path, category_name)) + + with open(category_name, "rb") as file: + prompt_yaml_file = yaml.full_load(file) + + prompt_list = utils.pattern_match( + prompt_name, prompt_yaml_file["prompts"].keys() + ) + + # category_name, *prompt_name = use_prompt.split(":") + # TODO allow to multiple prompt naming + # if len(prompt_name) > 1: + # prompt_list = [] + # for prompt in prompt_name: + # prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names)) + # else: + # prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names) + return [":".join([category_name, prompt]) for prompt in prompt_list] + + +class PromptString: + def __init__(self, prompt_string): + self.prompt_string = prompt_string + + def apply(self, doc): + doc_to_text = self.prompt_string["doc_to_text"] + doc_to_target = self.prompt_string["doc_to_target"] + + # TODO need a way to process doc_to_choice + if "doc_to_choice" in self.prompt_string: + raise NotImplementedError("Not yet implemented to accept doc_to_choice") + + text_string = 
utils.apply_template(doc_to_text, doc) + target_string = utils.apply_template(doc_to_target, doc) + + return [text_string, target_string] diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/__init__.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73d896452e06c2cc2909c290de70dcf87b0c6f90 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/__init__.py @@ -0,0 +1,670 @@ +import collections +import inspect +import logging +import os +from functools import partial +from typing import Dict, List, Mapping, Optional, Union + +from dllm_eval import utils +from dllm_eval.api.group import ConfigurableGroup, GroupConfig +from dllm_eval.api.task import ConfigurableTask, Task +from dllm_eval.evaluator_utils import get_subtask_list + + +GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys()) + +eval_logger = logging.getLogger(__name__) + + +class TaskManager: + """TaskManager indexes all tasks from the default `dllm_eval/tasks/` + and an optional directory if provided. + + """ + + def __init__( + self, + verbosity: Optional[str] = None, + include_path: Optional[Union[str, List]] = None, + include_defaults: bool = True, + metadata: Optional[dict] = None, + ) -> None: + if verbosity is not None: + utils.setup_logging(verbosity) + self.include_path = include_path + self.metadata = metadata + self._task_index = self.initialize_tasks( + include_path=include_path, include_defaults=include_defaults + ) + self._all_tasks = sorted(list(self._task_index.keys())) + + self._all_groups = sorted( + [x for x in self._all_tasks if self._task_index[x]["type"] == "group"] + ) + self._all_subtasks = sorted( + [ + x + for x in self._all_tasks + if self._task_index[x]["type"] in ["task", "python_task"] + ] + ) + self._all_tags = sorted( + [x for x in self._all_tasks if self._task_index[x]["type"] == "tag"] + ) + + self.task_group_map = collections.defaultdict(list) + + def initialize_tasks( + self, + include_path: Optional[Union[str, List]] = None, + include_defaults: bool = True, + ) -> dict[str, dict]: + """Creates a dictionary of tasks indexes. + + :param include_path: Union[str, List] = None + An additional path to be searched for tasks recursively. + Can provide more than one such path as a list. + :param include_defaults: bool = True + If set to false, default tasks (those in dllm_eval/tasks/) are not indexed. + return + Dictionary of task names as key and task metadata + """ + if include_defaults: + all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"] + else: + all_paths = [] + if include_path is not None: + if isinstance(include_path, str): + include_path = [include_path] + all_paths.extend(include_path) + + task_index = {} + for task_dir in all_paths: + tasks = self._get_task_and_group(task_dir) + task_index = {**tasks, **task_index} + + return task_index + + @property + def all_tasks(self): + return self._all_tasks + + @property + def all_groups(self): + return self._all_groups + + @property + def all_subtasks(self): + return self._all_subtasks + + @property + def all_tags(self): + return self._all_tags + + @property + def task_index(self): + return self._task_index + + def list_all_tasks( + self, list_groups=True, list_tags=True, list_subtasks=True + ) -> str: + from pytablewriter import MarkdownTableWriter + + def sanitize_path(path): + # don't print full path if we are within the dllm_eval/tasks dir ! + # if we aren't though, provide the full path. 
+ if "dllm_eval/tasks/" in path: + return "dllm_eval/tasks/" + path.split("dllm_eval/tasks/")[-1] + else: + return path + + group_table = MarkdownTableWriter() + group_table.headers = ["Group", "Config Location"] + gt_values = [] + for g in self.all_groups: + path = self.task_index[g]["yaml_path"] + if path == -1: + path = "---" + else: + path = sanitize_path(path) + gt_values.append([g, path]) + group_table.value_matrix = gt_values + + tag_table = MarkdownTableWriter() + tag_table.headers = ["Tag"] + tag_table.value_matrix = [[t] for t in self.all_tags] + + subtask_table = MarkdownTableWriter() + subtask_table.headers = ["Task", "Config Location", "Output Type"] + st_values = [] + for t in self.all_subtasks: + path = self.task_index[t]["yaml_path"] + + output_type = "" + + # read the yaml file to determine the output type + if path != -1: + config = utils.load_yaml_config(path, mode="simple") + if "output_type" in config: + output_type = config["output_type"] + elif ( + "include" in config + ): # if no output type, check if there is an include with an output type + include_path = path.split("/")[:-1] + config["include"] + include_config = utils.load_yaml_config(include_path, mode="simple") + if "output_type" in include_config: + output_type = include_config["output_type"] + + if path == -1: + path = "---" + else: + path = sanitize_path(path) + st_values.append([t, path, output_type]) + subtask_table.value_matrix = st_values + + result = "\n" + if list_groups: + result += group_table.dumps() + "\n\n" + if list_tags: + result += tag_table.dumps() + "\n\n" + if list_subtasks: + result += subtask_table.dumps() + "\n\n" + return result + + def match_tasks(self, task_list: list[str]) -> list[str]: + return utils.pattern_match(task_list, self.all_tasks) + + def _name_is_registered(self, name: str) -> bool: + if name in self.all_tasks: + return True + return False + + def _name_is_task(self, name: str) -> bool: + if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"): + return True + return False + + def _name_is_tag(self, name: str) -> bool: + if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"): + return True + return False + + def _name_is_group(self, name: str) -> bool: + if self._name_is_registered(name) and ( + self.task_index[name]["type"] == "group" + ): + return True + return False + + def _name_is_python_task(self, name: str) -> bool: + if self._name_is_registered(name) and ( + self.task_index[name]["type"] == "python_task" + ): + return True + return False + + def _config_is_task(self, config: dict) -> bool: + if ("task" in config) and isinstance(config["task"], str): + return True + return False + + def _config_is_group(self, config: dict) -> bool: + if ("task" in config) and isinstance(config["task"], list): + return True + return False + + def _config_is_python_task(self, config: dict) -> bool: + if "class" in config: + return True + return False + + def _get_yaml_path(self, name: str): + if name not in self.task_index: + raise ValueError + return self.task_index[name]["yaml_path"] + + def _get_config(self, name): + if name not in self.task_index: + raise ValueError + yaml_path = self._get_yaml_path(name) + if yaml_path == -1: + return {} + else: + return utils.load_yaml_config(yaml_path, mode="full") + + def _get_tasklist(self, name): + if self._name_is_task(name): + raise ValueError + return self.task_index[name]["task"] + + def _process_alias(self, config, group=None): + # If the group is not the same as the original + # group 
which the group alias was intended for, + # Set the group_alias to None instead. + if ("group_alias" in config) and ("group" in config) and group is not None: + if config["group"] != group: + config["group_alias"] = None + return config + + def _class_has_config_in_constructor(self, cls): + constructor = getattr(cls, "__init__", None) + return ( + "config" in inspect.signature(constructor).parameters + if constructor + else False + ) + + def _load_individual_task_or_group( + self, + name_or_config: Optional[Union[str, dict]] = None, + parent_name: Optional[str] = None, + update_config: Optional[dict] = None, + ) -> Mapping: + def _load_task(config, task): + if "include" in config: + config = { + **utils.load_yaml_config( + yaml_path=None, + yaml_config={"include": config.pop("include")}, + mode="full", + ), + **config, + } + if self._config_is_python_task(config): + if self._class_has_config_in_constructor(config["class"]): + task_object = config["class"](config=config) + else: + task_object = config["class"]() + if isinstance(task_object, ConfigurableTask): + # very scuffed: set task name here. TODO: fixme? + task_object.config.task = task + else: + if self.metadata is not None: + config["metadata"] = config.get("metadata", {}) | self.metadata + else: + config["metadata"] = config.get("metadata", {}) + task_object = ConfigurableTask(config=config) + + return {task: task_object} + + def _get_group_and_subtask_from_config( + config: dict, + ) -> tuple[ConfigurableGroup, list[str]]: + if self.metadata is not None: + config["metadata"] = config.get("metadata", {}) | self.metadata + group_name = ConfigurableGroup(config=config) + subtask_list = [] + for task in group_name.config["task"]: + if isinstance(task, str) and self._name_is_tag(task): + subtask_list.extend(self._get_tasklist(task)) + else: + subtask_list.append(task) + return group_name, subtask_list + + def _process_group_config( + config: dict, update_config: dict = None + ) -> tuple[dict, dict]: + if update_config is not None: + config = {**config, **update_config} + _update_config = { + k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS + } + if not bool(_update_config): + _update_config = None + + group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS} + return group_config, _update_config + + if isinstance(name_or_config, str): + if update_config is not None: + # Process name_or_config as a dict instead + name_or_config = {"task": name_or_config, **update_config} + elif self._name_is_task(name_or_config) or self._name_is_python_task( + name_or_config + ): + task_config = self._get_config(name_or_config) + return _load_task(task_config, task=name_or_config) + else: + subtask_list = self._get_tasklist(name_or_config) + if subtask_list == -1: + group_config = self._get_config(name_or_config) + group_config, update_config = _process_group_config(group_config) + group_name, subtask_list = _get_group_and_subtask_from_config( + group_config + ) + else: + if self._name_is_tag(name_or_config): + fn = partial( + self._load_individual_task_or_group, + update_config=name_or_config + if isinstance(name_or_config, dict) + else None, + ) + return dict( + collections.ChainMap(*map(fn, reversed(subtask_list))) + ) + else: + group_name = ConfigurableGroup( + config={"group": name_or_config, "task": subtask_list} + ) + + if isinstance(name_or_config, dict): + if self._config_is_task(name_or_config): + name = name_or_config.pop("task") + if update_config is not None: + name_or_config = {**name_or_config, **update_config} + 
# If the name is registered as a group + if self._name_is_group(name): + group_config = self._get_config(name) + + group_config, update_config = _process_group_config( + group_config, name_or_config + ) + group_name, subtask_list = _get_group_and_subtask_from_config( + group_config + ) + elif self._name_is_tag(name): + subtask_list = self._get_tasklist(name) + fn = partial( + self._load_individual_task_or_group, + update_config=name_or_config, + ) + return dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) + else: + if self._name_is_registered(name): + base_task_config = self._get_config(name) + + # Check if this is a duplicate. + if parent_name is not None: + num_duplicate = len( + list( + filter( + lambda x: x.startswith(name), + self.task_group_map[parent_name], + ) + ) + ) + if num_duplicate > 0: + name = f"{name}-{num_duplicate}" + self.task_group_map[parent_name].append(name) + + task_config = { + **base_task_config, + **name_or_config, + } + else: + task_config = name_or_config + return _load_task(task_config, task=name) + else: + group_config, update_config = _process_group_config(name_or_config) + group_name, subtask_list = _get_group_and_subtask_from_config( + group_config + ) + + fn = partial( + self._load_individual_task_or_group, + parent_name=group_name, + update_config=update_config, + ) + return { + group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) + } + + def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict: + """Loads a dictionary of task objects from a list + + :param task_list: Union[str, list] = None + Single string or list of string of task names to be loaded + + :return + Dictionary of task objects + """ + if isinstance(task_list, str): + task_list = [task_list] + + all_loaded_tasks = dict( + collections.ChainMap( + *map( + lambda task: self._load_individual_task_or_group(task), + task_list, + ) + ) + ) + return all_loaded_tasks + + def load_config(self, config: Dict): + return self._load_individual_task_or_group(config) + + def _get_task_and_group(self, task_dir: str): + """Creates a dictionary of tasks index with the following metadata, + - `type`, that can be either `task`, `python_task`, `group` or `tags`. + `task` refer to regular task configs, `python_task` are special + yaml files that only consists of `task` and `class` parameters. + `group` are group configs. `tags` are labels that can be assigned + to tasks to assist in sorting and calling tasks of certain themes. + - `yaml_path`, path to the yaml file. If the entry is a `group` that + was configured through a task config, the yaml_path will be -1 + and all subtasks will be listed in `task` (see below) + - `task`, reserved for entries with `type` as `group`. This will list + all subtasks. When a group config is created (as opposed to task + config having `group` parameter set), this will be set to -1 to + avoid recursive indexing. The whole list of subtasks will be loaded + at evaluation. 
+ + :param task_dir: str + A directory to check for tasks + + :return + Dictionary of task names as key and task metadata + """ + + def _populate_tags_and_groups(config, task, tasks_and_groups, print_info): + # TODO: remove group in next release + if "tag" in config: + attr_list = config["tag"] + if isinstance(attr_list, str): + attr_list = [attr_list] + + for tag in attr_list: + if tag not in tasks_and_groups: + tasks_and_groups[tag] = { + "type": "tag", + "task": [task], + "yaml_path": -1, + } + elif tasks_and_groups[tag]["type"] != "tag": + eval_logger.info( + f"The tag '{tag}' is already registered as a group, this tag will not be registered. " + "This may affect tasks you want to call." + ) + break + else: + tasks_and_groups[tag]["task"].append(task) + + # TODO: remove group in next release + print_info = True + ignore_dirs = [ + "__pycache__", + ".ipynb_checkpoints", + ] + tasks_and_groups = collections.defaultdict() + for root, dirs, file_list in os.walk(task_dir): + dirs[:] = [d for d in dirs if d not in ignore_dirs] + for f in file_list: + if f.endswith(".yaml"): + yaml_path = os.path.join(root, f) + print(yaml_path) + config = utils.load_yaml_config(yaml_path, mode="simple") + if self._config_is_python_task(config): + # This is a python class config + task = config["task"] + tasks_and_groups[task] = { + "type": "python_task", + "yaml_path": yaml_path, + } + _populate_tags_and_groups( + config, task, tasks_and_groups, print_info + ) + elif self._config_is_group(config): + # This is a group config + tasks_and_groups[config["group"]] = { + "type": "group", + "task": -1, # This signals that + # we don't need to know + # the task list for indexing + # as it can be loaded + # when called. + "yaml_path": yaml_path, + } + + # # Registered the level 1 tasks from a group config + # for config in config["task"]: + # if isinstance(config, dict) and self._config_is_task(config): + # task = config["task"] + # tasks_and_groups[task] = { + # "type": "task", + # "yaml_path": yaml_path, + # } + + elif self._config_is_task(config): + # This is a task config + task = config["task"] + tasks_and_groups[task] = { + "type": "task", + "yaml_path": yaml_path, + } + _populate_tags_and_groups( + config, task, tasks_and_groups, print_info + ) + else: + eval_logger.debug(f"File {f} in {root} could not be loaded") + + return tasks_and_groups + + +def get_task_name_from_config(task_config: Dict[str, str]) -> str: + if "task" in task_config: + return task_config["task"] + if "dataset_name" in task_config: + return "{dataset_path}_{dataset_name}".format(**task_config) + else: + return "{dataset_path}".format(**task_config) + + +def get_task_name_from_object(task_object): + if hasattr(task_object, "config"): + return task_object._config["task"] + + # TODO: scrap this + # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting + return ( + task_object.EVAL_HARNESS_NAME + if hasattr(task_object, "EVAL_HARNESS_NAME") + else type(task_object).__name__ + ) + + +def _check_duplicates(task_dict: dict) -> None: + """helper function solely used in validating get_task_dict output. + Takes the output of dllm_eval.evaluator_utils.get_subtask_list and + returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are + "oversubscribed" to several disjoint groups. 
+ """ + subtask_names = [] + for key, value in task_dict.items(): + subtask_names.extend(value) + + duplicate_tasks = { + task_name for task_name in subtask_names if subtask_names.count(task_name) > 1 + } + + # locate the potentially problematic groups that seem to 'compete' for constituent subtasks + competing_groups = [ + group + for group in task_dict.keys() + if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0 + ] + + if len(duplicate_tasks) > 0: + raise ValueError( + f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs." + ) + + +def get_task_dict( + task_name_list: Union[str, List[Union[str, Dict, Task]]], + task_manager: Optional[TaskManager] = None, +): + """Creates a dictionary of task objects from either a name of task, config, or prepared Task object. + + :param task_name_list: List[Union[str, Dict, Task]] + Name of model or LM object, see dllm_eval.models.get_model + :param task_manager: TaskManager = None + A TaskManager object that stores indexed tasks. If not set, + task_manager will load one. This should be set by the user + if there are additional paths that want to be included + via `include_path` + + :return + Dictionary of task objects + """ + + task_name_from_string_dict = {} + task_name_from_config_dict = {} + task_name_from_object_dict = {} + + if isinstance(task_name_list, str): + task_name_list = [task_name_list] + elif isinstance(task_name_list, list): + if not all([isinstance(task, (str, dict, Task)) for task in task_name_list]): + raise TypeError( + "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match." + ) + else: + raise TypeError( + f"Expected a 'str' or 'list' but received {type(task_name_list)}." + ) + + string_task_name_list = [task for task in task_name_list if isinstance(task, str)] + others_task_name_list = [ + task for task in task_name_list if not isinstance(task, str) + ] + if len(string_task_name_list) > 0: + if task_manager is None: + task_manager = TaskManager() + + task_name_from_string_dict = task_manager.load_task_or_group( + string_task_name_list + ) + + for task_element in others_task_name_list: + if isinstance(task_element, dict): + task_name_from_config_dict = { + **task_name_from_config_dict, + **task_manager.load_config(config=task_element), + } + + elif isinstance(task_element, Task): + task_name_from_object_dict = { + **task_name_from_object_dict, + get_task_name_from_object(task_element): task_element, + } + + if not set(task_name_from_string_dict.keys()).isdisjoint( + set(task_name_from_object_dict.keys()) + ): + raise ValueError + + final_task_dict = { + **task_name_from_string_dict, + **task_name_from_config_dict, + **task_name_from_object_dict, + } + + # behavior can get odd if one tries to invoke several groups that "compete" for the same task. + # (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask + # and we'd be unsure which to use and report.) + # we explicitly check and error in this case. 
+ _check_duplicates(get_subtask_list(final_task_dict)) + + return final_task_dict diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/gsm8k/gsm8k.yaml b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/gsm8k/gsm8k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8c56206923cf19bac4ec07233c6b0b17ac0460ad --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/gsm8k/gsm8k.yaml @@ -0,0 +1,15 @@ +task: gsm8k +dataset_path: openai/gsm8k +dataset_name: main +output_type: generate_until +training_split: train +fewshot_split: train +test_split: test +doc_to_text: !function utils.gsm_prompt +doc_to_target: "{{answer.split('####')[-1].strip()}}" +generation_kwargs: + until: + - "[NO_UNTIL_PLACEHOLDER]" + do_sample: false +repeats: 1 +num_fewshot: 0 diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/gsm8k/utils.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/gsm8k/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8ceaa3d2ab7af89f27e69b470a2f6787f6133519 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/gsm8k/utils.py @@ -0,0 +1,13 @@ +def gsm_prompt(doc): + system_prompt = ( + "You are a math expert. You will be given a question to solve. Solve it step by step. Wrap the final answer in a \\boxed{}. \n" + "Respond in the following format:\n" + "\n" + "Your reasoning here\n" + "\n" + "\n" + "\\boxed{...}\n" + "" + ) + prompt = f"{system_prompt}\n\n{doc['question']}\n\n" + return prompt diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/humaneval/humaneval.yaml b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/humaneval/humaneval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..024d38f0da160e853cd8c3123104a4485677c0fd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/humaneval/humaneval.yaml @@ -0,0 +1,13 @@ +task: humaneval +dataset_path: openai/openai_humaneval +unsafe_code: true +output_type: generate_until +test_split: test +doc_to_text: "Write a solution to the following problem and make sure that it passes the tests:\n{{prompt}}\n\nFirst, reason about the solution step-by-step. 
Then, write the code.\nRespond in the following format:\n\nYour reasoning here\n\n\n```python\nThe complete implementation of the {{entry_point}} function\n```\n"
+doc_to_target: "{{test}}\ncheck({{entry_point}})"
+generation_kwargs:
+  until:
+    - "[NO_UNTIL_PLACEHOLDER]"
+  do_sample: false
+repeats: 1
+num_fewshot: 0
diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/humaneval/utils.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/humaneval/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..11bac61cfa12fad57aacfed28b55bee467cf23e4
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/humaneval/utils.py
@@ -0,0 +1,42 @@
+import evaluate as hf_evaluate
+
+
+try:
+    # run a simple test to check that code execution is enabled before generation
+    compute_ = hf_evaluate.load("code_eval")
+    test_cases = ["assert add(2, 3)==5"]
+    candidates = [["def add(a,b): return a*b"]]
+    results = compute_.compute(references=test_cases, predictions=candidates, k=[1])
+except Exception as e:
+    raise e
+
+
+def pass_at_k(references: list[str], predictions: list[list[str]], k: list[int] = None):
+    global compute_
+    assert k is not None
+    if isinstance(k, int):
+        k = [k]
+    res = compute_.compute(
+        references=references,
+        predictions=predictions,
+        k=k
+    )
+    return res[0]
+
+
+def clean_response_string(r: str) -> str:
+    # keep only the contents of the last ```python fenced block (dropping the fence
+    # markers themselves) and cut everything from an `if __name__ == "__main__":` guard onwards
+    marker = "```python"
+    cleaned_text = r if r.rfind(marker) == -1 else r[r.rfind(marker) + len(marker):]
+    cleaned_text = cleaned_text if cleaned_text.rfind("```") == -1 else cleaned_text[: cleaned_text.rfind("```")]
+    cleaned_text = cleaned_text if cleaned_text.rfind("if __name__ == \"__main__\":") == -1 else cleaned_text[: cleaned_text.rfind("if __name__ == \"__main__\":")]
+    return cleaned_text
+
+
+def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
+    # predictions are the cleaned code blocks only; the check code comes from doc_to_target
+    return [
+        [clean_response_string(r) for r in resp]
+        for resp, doc in zip(resps, docs)
+    ]
diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/mbpp/mbpp.yaml b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/mbpp/mbpp.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f5b9755ad30669e2335bd374ba5f53db0572630f
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/mbpp/mbpp.yaml
@@ -0,0 +1,14 @@
+task: mbpp
+dataset_path: google-research-datasets/mbpp
+dataset_name: full
+unsafe_code: true
+output_type: generate_until
+test_split: test
+doc_to_text: "\n{{text}} Your code should pass these tests:\n\n{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}} \n\nFirst, reason about the solution step-by-step.
Then, write the code.\nRespond in the following format:\n\nYour reasoning here\n\n\n```python\nThe complete implementation of the function\n```\n" +doc_to_target: "{% if is_fewshot is defined %}{{code}}\n[DONE]{% else %}{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}{% endif %}" +target_delimiter: "" +generation_kwargs: + until: + - "[NO_UNTIL_PLACEHOLDER]" + do_sample: false +num_fewshot: 0 diff --git a/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/mbpp/utils.py b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/mbpp/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..234fc7ed5de047e556dea2ff77d02a232c8f3e6e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Baseline/dllm_eval/tasks/mbpp/utils.py @@ -0,0 +1,79 @@ +import re +from typing import Union + +import evaluate as hf_evaluate + + +try: + pass_at_k = hf_evaluate.load("code_eval") + + # run simple test to check code execution is enabled before model generation + test_cases = ["assert add(2, 3)==5"] + candidates = [["def add(a,b): return a*b"]] + results = pass_at_k.compute(references=test_cases, predictions=candidates, k=[1]) +except Exception as e: + raise e + + +def pass_at_1( + references: Union[str, list[str]], predictions: Union[str, list[list[str]]] +) -> float: + if isinstance(references, str): + references = [references] + if isinstance(predictions[0], str): + predictions = [[p] for p in predictions] + return pass_at_k.compute( + references=references, + predictions=predictions, + k=[1], + num_workers=48 + )[0]["pass@1"] + + +def extract_code_blocks(text: str) -> str: + text = re.sub(r"\[DONE\]", "", text) + text = re.sub(r"<\|eot_id\|>", "", text) + text = re.sub(r"<\|endoftext\|>", "", text) + return text + + +def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]: + return [[extract_code_blocks(r) for r in resp] for resp in resps] + + +def list_fewshot_samples(): + return [ + { + "task_id": 2, + "text": "Write a function to find the similar elements from the given two tuple lists.", + "code": "def similar_elements(test_tup1, test_tup2):\r\n res = tuple(set(test_tup1) & set(test_tup2))\r\n return (res) ", + "test_list": [ + "assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)", + "assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)", + "assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)", + ], + "is_fewshot": True, + }, + { + "task_id": 3, + "text": "Write a python function to identify non-prime numbers.", + "code": "import math\r\ndef is_not_prime(n):\r\n result = False\r\n for i in range(2,int(math.sqrt(n)) + 1):\r\n if n % i == 0:\r\n result = True\r\n return result", + "test_list": [ + "assert is_not_prime(2) == False", + "assert is_not_prime(10) == True", + "assert is_not_prime(35) == True", + ], + "is_fewshot": True, + }, + { + "task_id": 4, + "text": "Write a function to find the largest integers from a given list of numbers using heap queue algorithm.", + "code": "import heapq as hq\r\ndef heap_queue_largest(nums,n):\r\n largest_nums = hq.nlargest(n, nums)\r\n return largest_nums", + "test_list": [ + "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] ", + "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] ", + "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35]", + ], + "is_fewshot": True, + }, + ] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/certifi/__main__.py 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/certifi/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..8945b5da857f4a7dec2b84f1225f012f6098418c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/certifi/__main__.py @@ -0,0 +1,12 @@ +import argparse + +from certifi import contents, where + +parser = argparse.ArgumentParser() +parser.add_argument("-c", "--contents", action="store_true") +args = parser.parse_args() + +if args.contents: + print(contents()) +else: + print(where()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/INSTALLER b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/LICENSE b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7082a2d5b9047bfc09589f387053e24ea490bc54 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/LICENSE @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2013-2019 Nikolay Kim and Andrew Svetlov + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/METADATA b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..5b96937af16d2f8fab07eb0cd808e7f9f9d9e509 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/METADATA @@ -0,0 +1,477 @@ +Metadata-Version: 2.1 +Name: frozenlist +Version: 1.5.0 +Summary: A list-like structure which implements collections.abc.MutableSequence +Home-page: https://github.com/aio-libs/frozenlist +Maintainer: aiohttp team +Maintainer-email: team@aiohttp.org +License: Apache 2 +Project-URL: Chat: Matrix, https://matrix.to/#/#aio-libs:matrix.org +Project-URL: Chat: Matrix Space, https://matrix.to/#/#aio-libs-space:matrix.org +Project-URL: CI: Github Actions, https://github.com/aio-libs/frozenlist/actions +Project-URL: Code of Conduct, https://github.com/aio-libs/.github/blob/master/CODE_OF_CONDUCT.md +Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/frozenlist +Project-URL: Docs: Changelog, https://github.com/aio-libs/frozenlist/blob/master/CHANGES.rst#changelog +Project-URL: Docs: RTD, https://frozenlist.aio-libs.org +Project-URL: GitHub: issues, https://github.com/aio-libs/frozenlist/issues +Project-URL: GitHub: repo, https://github.com/aio-libs/frozenlist +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Operating System :: POSIX +Classifier: Operating System :: MacOS :: MacOS X +Classifier: Operating System :: Microsoft :: Windows +Classifier: Programming Language :: Cython +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Requires-Python: >=3.8 +Description-Content-Type: text/x-rst +License-File: LICENSE + +frozenlist +========== + +.. image:: https://github.com/aio-libs/frozenlist/workflows/CI/badge.svg + :target: https://github.com/aio-libs/frozenlist/actions + :alt: GitHub status for master branch + +.. image:: https://codecov.io/gh/aio-libs/frozenlist/branch/master/graph/badge.svg + :target: https://codecov.io/gh/aio-libs/frozenlist + :alt: codecov.io status for master branch + +.. image:: https://img.shields.io/pypi/v/frozenlist.svg?logo=Python&logoColor=white + :target: https://pypi.org/project/frozenlist + :alt: frozenlist @ PyPI + +.. image:: https://readthedocs.org/projects/frozenlist/badge/?version=latest + :target: https://frozenlist.aio-libs.org + :alt: Read The Docs build status badge + +.. image:: https://img.shields.io/matrix/aio-libs:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat + :target: https://matrix.to/#/%23aio-libs:matrix.org + :alt: Matrix Room — #aio-libs:matrix.org + +.. 
image:: https://img.shields.io/matrix/aio-libs-space:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs-space%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat + :target: https://matrix.to/#/%23aio-libs-space:matrix.org + :alt: Matrix Space — #aio-libs-space:matrix.org + +Introduction +------------ + +``frozenlist.FrozenList`` is a list-like structure which implements +``collections.abc.MutableSequence``. The list is *mutable* until ``FrozenList.freeze`` +is called, after which list modifications raise ``RuntimeError``: + + +>>> from frozenlist import FrozenList +>>> fl = FrozenList([17, 42]) +>>> fl.append('spam') +>>> fl.append('Vikings') +>>> fl + +>>> fl.freeze() +>>> fl + +>>> fl.frozen +True +>>> fl.append("Monty") +Traceback (most recent call last): + File "", line 1, in + File "frozenlist/_frozenlist.pyx", line 97, in frozenlist._frozenlist.FrozenList.append + self._check_frozen() + File "frozenlist/_frozenlist.pyx", line 19, in frozenlist._frozenlist.FrozenList._check_frozen + raise RuntimeError("Cannot modify frozen list.") +RuntimeError: Cannot modify frozen list. + + +FrozenList is also hashable, but only when frozen. Otherwise it also throws a RuntimeError: + + +>>> fl = FrozenList([17, 42, 'spam']) +>>> hash(fl) +Traceback (most recent call last): + File "", line 1, in + File "frozenlist/_frozenlist.pyx", line 111, in frozenlist._frozenlist.FrozenList.__hash__ + raise RuntimeError("Cannot hash unfrozen list.") +RuntimeError: Cannot hash unfrozen list. +>>> fl.freeze() +>>> hash(fl) +3713081631934410656 +>>> dictionary = {fl: 'Vikings'} # frozen fl can be a dict key +>>> dictionary +{: 'Vikings'} + + +Installation +------------ + +:: + + $ pip install frozenlist + +The library requires Python 3.8 or newer. + + +Documentation +------------- + +https://frozenlist.aio-libs.org + +Communication channels +---------------------- + +We have a *Matrix Space* `#aio-libs-space:matrix.org +`_ which is +also accessible via Gitter. + +Requirements +------------ + +- Python >= 3.8 + +License +------- + +``frozenlist`` is offered under the Apache 2 license. + +Source code +----------- + +The project is hosted on GitHub_ + +Please file an issue in the `bug tracker +`_ if you have found a bug +or have some suggestions to improve the library. + +.. _GitHub: https://github.com/aio-libs/frozenlist + +========= +Changelog +========= + +.. + You should *NOT* be adding new change log entries to this file, this + file is managed by towncrier. You *may* edit previous change logs to + fix problems like typo corrections or such. + To add a new change log entry, please see + https://pip.pypa.io/en/latest/development/contributing/#news-entries + we named the news folder "changes". + + WARNING: Don't drop the next directive! + +.. towncrier release notes start + +1.5.0 (2024-10-22) +================== + +Bug fixes +--------- + +- An incorrect signature of the ``__class_getitem__`` class method + has been fixed, adding a missing ``class_item`` argument under + Python 3.8 and older. + + This change also improves the code coverage of this method that + was previously missing -- by `@webknjaz `__. + + + *Related issues and pull requests on GitHub:* + `#567 `__, `#571 `__. + + +Improved documentation +---------------------- + +- Rendered issue, PR, and commit links now lead to + ``frozenlist``'s repo instead of ``yarl``'s repo. + + + *Related issues and pull requests on GitHub:* + `#573 `__. + +- On the ``Contributing docs`` page, + a link to the ``Towncrier philosophy`` has been fixed. 
+ + + *Related issues and pull requests on GitHub:* + `#574 `__. + + +Packaging updates and notes for downstreams +------------------------------------------- + +- A name of a temporary building directory now reflects + that it's related to ``frozenlist``, not ``yarl``. + + + *Related issues and pull requests on GitHub:* + `#573 `__. + +- Declared Python 3.13 supported officially in the distribution package metadata. + + + *Related issues and pull requests on GitHub:* + `#595 `__. + + +---- + + +1.4.1 (2023-12-15) +================== + +Packaging updates and notes for downstreams +------------------------------------------- + +- Declared Python 3.12 and PyPy 3.8-3.10 supported officially + in the distribution package metadata. + + + *Related issues and pull requests on GitHub:* + `#553 `__. + +- Replaced the packaging is replaced from an old-fashioned ``setup.py`` to an + in-tree `PEP 517 `__ build backend -- by `@webknjaz `__. + + Whenever the end-users or downstream packagers need to build ``frozenlist`` + from source (a Git checkout or an sdist), they may pass a ``config_settings`` + flag ``pure-python``. If this flag is not set, a C-extension will be built + and included into the distribution. + + Here is how this can be done with ``pip``: + + .. code-block:: console + + $ python3 -m pip install . --config-settings=pure-python= + + This will also work with ``-e | --editable``. + + The same can be achieved via ``pypa/build``: + + .. code-block:: console + + $ python3 -m build --config-setting=pure-python= + + Adding ``-w | --wheel`` can force ``pypa/build`` produce a wheel from source + directly, as opposed to building an ``sdist`` and then building from it. + + + *Related issues and pull requests on GitHub:* + `#560 `__. + + +Contributor-facing changes +-------------------------- + +- It is now possible to request line tracing in Cython builds using the + ``with-cython-tracing`` `PEP 517 `__ config setting + -- `@webknjaz `__. + + This can be used in CI and development environment to measure coverage + on Cython modules, but is not normally useful to the end-users or + downstream packagers. + + Here's a usage example: + + .. code-block:: console + + $ python3 -Im pip install . --config-settings=with-cython-tracing=true + + For editable installs, this setting is on by default. Otherwise, it's + off unless requested explicitly. + + The following produces C-files required for the Cython coverage + plugin to map the measurements back to the PYX-files: + + .. code-block:: console + + $ python -Im pip install -e . + + Alternatively, the ``FROZENLIST_CYTHON_TRACING=1`` environment variable + can be set to do the same as the `PEP 517 `__ config setting. + + + *Related issues and pull requests on GitHub:* + `#560 `__. + +- Coverage collection has been implemented for the Cython modules + -- by `@webknjaz `__. + + It will also be reported to Codecov from any non-release CI jobs. + + + *Related issues and pull requests on GitHub:* + `#561 `__. + +- A step-by-step ``Release Guide`` guide has + been added, describing how to release *frozenlist* -- by `@webknjaz `__. + + This is primarily targeting the maintainers. + + + *Related issues and pull requests on GitHub:* + `#563 `__. + +- Detailed ``Contributing Guidelines`` on + authoring the changelog fragments have been published in the + documentation -- by `@webknjaz `__. + + + *Related issues and pull requests on GitHub:* + `#564 `__. 
+ + +---- + + +1.4.0 (2023-07-12) +================== + +The published source distribution package became buildable +under Python 3.12. + + +---- + + +Bugfixes +-------- + +- Removed an unused ``typing.Tuple`` import + `#411 `_ + + +Deprecations and Removals +------------------------- + +- Dropped Python 3.7 support. + `#413 `_ + + +Misc +---- + +- `#410 `_, `#433 `_ + + +---- + + +1.3.3 (2022-11-08) +================== + +- Fixed CI runs when creating a new release, where new towncrier versions + fail when the current version section is already present. + + +---- + + +1.3.2 (2022-11-08) +================== + +Misc +---- + +- Updated the CI runs to better check for test results and to avoid deprecated syntax. `#327 `_ + + +---- + + +1.3.1 (2022-08-02) +================== + +The published source distribution package became buildable +under Python 3.11. + + +---- + + +1.3.0 (2022-01-18) +================== + +Bugfixes +-------- + +- Do not install C sources with binary distributions. + `#250 `_ + + +Deprecations and Removals +------------------------- + +- Dropped Python 3.6 support + `#274 `_ + + +---- + + +1.2.0 (2021-10-16) +================== + +Features +-------- + +- ``FrozenList`` now supports being used as a generic type as per PEP 585, e.g. ``frozen_int_list: FrozenList[int]`` (requires Python 3.9 or newer). + `#172 `_ +- Added support for Python 3.10. + `#227 `_ +- Started shipping platform-specific wheels with the ``musl`` tag targeting typical Alpine Linux runtimes. + `#227 `_ +- Started shipping platform-specific arm64 wheels for Apple Silicon. + `#227 `_ + + +---- + + +1.1.1 (2020-11-14) +================== + +Bugfixes +-------- + +- Provide x86 Windows wheels. + `#169 `_ + + +---- + + +1.1.0 (2020-10-13) +================== + +Features +-------- + +- Add support for hashing of a frozen list. + `#136 `_ + +- Support Python 3.8 and 3.9. + +- Provide wheels for ``aarch64``, ``i686``, ``ppc64le``, ``s390x`` architectures on + Linux as well as ``x86_64``. + + +---- + + +1.0.0 (2019-11-09) +================== + +Deprecations and Removals +------------------------- + +- Dropped support for Python 3.5; only 3.6, 3.7 and 3.8 are supported going forward. 
+ `#24 `_ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/RECORD b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..904627f1bc3bbffdb52033f9b92b1853e69196bb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/RECORD @@ -0,0 +1,12 @@ +frozenlist-1.5.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +frozenlist-1.5.0.dist-info/LICENSE,sha256=b9UkPpLdf5jsacesN3co50kFcJ_1J6W_mNbQJjwE9bY,11332 +frozenlist-1.5.0.dist-info/METADATA,sha256=BpQvB7z2NbU3f4XTQDvhAZ9L08WR4XiYajilj9IY6Yk,13762 +frozenlist-1.5.0.dist-info/RECORD,, +frozenlist-1.5.0.dist-info/WHEEL,sha256=64hRuO2b8JU2aeheZgbK9oQwal3JVqwtqRhpQNr8ZdQ,224 +frozenlist-1.5.0.dist-info/top_level.txt,sha256=jivtxsPXA3nK3WBWW2LW5Mtu_GHt8UZA13NeCs2cKuA,11 +frozenlist/__init__.py,sha256=ymVtnW3MinO-Ux3cBj_PLEpXnmLawk45el8vcX6IkWY,2371 +frozenlist/__init__.pyi,sha256=vMEoES1xGegPtVXoCi9XydEeHsyuIq-KdeXwP5PdsaA,1470 +frozenlist/__pycache__/__init__.cpython-312.pyc,, +frozenlist/_frozenlist.cpython-312-x86_64-linux-gnu.so,sha256=n65G8t1lqSUcWICd9rjOJujV1lxtniI2JJQQXtc7BjQ,961592 +frozenlist/_frozenlist.pyx,sha256=4YturclNF7wioO7YX3Vzl7Ldb2-iswe6UrjJOMKSswU,2993 +frozenlist/py.typed,sha256=sow9soTwP9T_gEAQSVh7Gb8855h04Nwmhs2We-JRgZM,7 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/WHEEL b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..37ea9c3a91170f8ce4fd647cca693c51e36a4bfd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/WHEEL @@ -0,0 +1,8 @@ +Wheel-Version: 1.0 +Generator: setuptools (75.2.0) +Root-Is-Purelib: false +Tag: cp312-cp312-manylinux_2_5_x86_64 +Tag: cp312-cp312-manylinux1_x86_64 +Tag: cp312-cp312-manylinux_2_17_x86_64 +Tag: cp312-cp312-manylinux2014_x86_64 + diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/top_level.txt b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..52f13fc459edf8f3def6f792c432f0b64f313176 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/frozenlist-1.5.0.dist-info/top_level.txt @@ -0,0 +1 @@ +frozenlist diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/INSTALLER b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/LICENSE.txt b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..8f080eae848f759c9173bfc0c79506357ebe5090 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/LICENSE.txt @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2016 Nathaniel J. 
Smith and other contributors + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/METADATA b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..cf12a82f193d8a69b9bc7aaa134cdbb8aa5bd938 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/METADATA @@ -0,0 +1,193 @@ +Metadata-Version: 2.1 +Name: h11 +Version: 0.14.0 +Summary: A pure-Python, bring-your-own-I/O implementation of HTTP/1.1 +Home-page: https://github.com/python-hyper/h11 +Author: Nathaniel J. Smith +Author-email: njs@pobox.com +License: MIT +Classifier: Development Status :: 3 - Alpha +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Topic :: Internet :: WWW/HTTP +Classifier: Topic :: System :: Networking +Requires-Python: >=3.7 +License-File: LICENSE.txt +Requires-Dist: typing-extensions ; python_version < "3.8" + +h11 +=== + +.. image:: https://travis-ci.org/python-hyper/h11.svg?branch=master + :target: https://travis-ci.org/python-hyper/h11 + :alt: Automated test status + +.. image:: https://codecov.io/gh/python-hyper/h11/branch/master/graph/badge.svg + :target: https://codecov.io/gh/python-hyper/h11 + :alt: Test coverage + +.. image:: https://readthedocs.org/projects/h11/badge/?version=latest + :target: http://h11.readthedocs.io/en/latest/?badge=latest + :alt: Documentation Status + +This is a little HTTP/1.1 library written from scratch in Python, +heavily inspired by `hyper-h2 `_. + +It's a "bring-your-own-I/O" library; h11 contains no IO code +whatsoever. This means you can hook h11 up to your favorite network +API, and that could be anything you want: synchronous, threaded, +asynchronous, or your own implementation of `RFC 6214 +`_ -- h11 won't judge you. 
+(Compare this to the current state of the art, where every time a `new +network API `_ comes along then someone +gets to start over reimplementing the entire HTTP protocol from +scratch.) Cory Benfield made an `excellent blog post describing the +benefits of this approach +`_, or if you like video +then here's his `PyCon 2016 talk on the same theme +`_. + +This also means that h11 is not immediately useful out of the box: +it's a toolkit for building programs that speak HTTP, not something +that could directly replace ``requests`` or ``twisted.web`` or +whatever. But h11 makes it much easier to implement something like +``requests`` or ``twisted.web``. + +At a high level, working with h11 goes like this: + +1) First, create an ``h11.Connection`` object to track the state of a + single HTTP/1.1 connection. + +2) When you read data off the network, pass it to + ``conn.receive_data(...)``; you'll get back a list of objects + representing high-level HTTP "events". + +3) When you want to send a high-level HTTP event, create the + corresponding "event" object and pass it to ``conn.send(...)``; + this will give you back some bytes that you can then push out + through the network. + +For example, a client might instantiate and then send a +``h11.Request`` object, then zero or more ``h11.Data`` objects for the +request body (e.g., if this is a POST), and then a +``h11.EndOfMessage`` to indicate the end of the message. Then the +server would then send back a ``h11.Response``, some ``h11.Data``, and +its own ``h11.EndOfMessage``. If either side violates the protocol, +you'll get a ``h11.ProtocolError`` exception. + +h11 is suitable for implementing both servers and clients, and has a +pleasantly symmetric API: the events you send as a client are exactly +the ones that you receive as a server and vice-versa. + +`Here's an example of a tiny HTTP client +`_ + +It also has `a fine manual `_. + +FAQ +--- + +*Whyyyyy?* + +I wanted to play with HTTP in `Curio +`__ and `Trio +`__, which at the time didn't have any +HTTP libraries. So I thought, no big deal, Python has, like, a dozen +different implementations of HTTP, surely I can find one that's +reusable. I didn't find one, but I did find Cory's call-to-arms +blog-post. So I figured, well, fine, if I have to implement HTTP from +scratch, at least I can make sure no-one *else* has to ever again. + +*Should I use it?* + +Maybe. You should be aware that it's a very young project. But, it's +feature complete and has an exhaustive test-suite and complete docs, +so the next step is for people to try using it and see how it goes +:-). If you do then please let us know -- if nothing else we'll want +to talk to you before making any incompatible changes! + +*What are the features/limitations?* + +Roughly speaking, it's trying to be a robust, complete, and non-hacky +implementation of the first "chapter" of the HTTP/1.1 spec: `RFC 7230: +HTTP/1.1 Message Syntax and Routing +`_. That is, it mostly focuses on +implementing HTTP at the level of taking bytes on and off the wire, +and the headers related to that, and tries to be anal about spec +conformance. It doesn't know about higher-level concerns like URL +routing, conditional GETs, cross-origin cookie policies, or content +negotiation. 
But it does know how to take care of framing, +cross-version differences in keep-alive handling, and the "obsolete +line folding" rule, so you can focus your energies on the hard / +interesting parts for your application, and it tries to support the +full specification in the sense that any useful HTTP/1.1 conformant +application should be able to use h11. + +It's pure Python, and has no dependencies outside of the standard +library. + +It has a test suite with 100.0% coverage for both statements and +branches. + +Currently it supports Python 3 (testing on 3.7-3.10) and PyPy 3. +The last Python 2-compatible version was h11 0.11.x. +(Originally it had a Cython wrapper for `http-parser +`_ and a beautiful nested state +machine implemented with ``yield from`` to postprocess the output. But +I had to take these out -- the new *parser* needs fewer lines-of-code +than the old *parser wrapper*, is written in pure Python, uses no +exotic language syntax, and has more features. It's sad, really; that +old state machine was really slick. I just need a few sentences here +to mourn that.) + +I don't know how fast it is. I haven't benchmarked or profiled it yet, +so it's probably got a few pointless hot spots, and I've been trying +to err on the side of simplicity and robustness instead of +micro-optimization. But at the architectural level I tried hard to +avoid fundamentally bad decisions, e.g., I believe that all the +parsing algorithms remain linear-time even in the face of pathological +input like slowloris, and there are no byte-by-byte loops. (I also +believe that it maintains bounded memory usage in the face of +arbitrary/pathological input.) + +The whole library is ~800 lines-of-code. You can read and understand +the whole thing in less than an hour. Most of the energy invested in +this so far has been spent on trying to keep things simple by +minimizing special-cases and ad hoc state manipulation; even though it +is now quite small and simple, I'm still annoyed that I haven't +figured out how to make it even smaller and simpler. (Unfortunately, +HTTP does not lend itself to simplicity.) + +The API is ~feature complete and I don't expect the general outlines +to change much, but you can't judge an API's ergonomics until you +actually document and use it, so I'd expect some changes in the +details. + +*How do I try it?* + +.. code-block:: sh + + $ pip install h11 + $ git clone git@github.com:python-hyper/h11 + $ cd h11/examples + $ python basic-client.py + +and go from there. + +*License?* + +MIT + +*Code of conduct?* + +Contributors are requested to follow our `code of conduct +`_ in +all project spaces. 
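To make the request/response flow described in the README above concrete, here is a minimal, illustrative client sketch against the h11 API. The host name, port, and blocking-socket I/O are placeholders, and the sketch pulls events with ``conn.next_event()`` after feeding bytes in via ``conn.receive_data()``, which is how this vintage of h11 exposes the event stream (the prose above summarizes the same loop at a higher level).

.. code-block:: python

    import socket
    import h11

    conn = h11.Connection(our_role=h11.CLIENT)
    sock = socket.create_connection(("example.com", 80))  # placeholder host

    # Serialize a GET request and the end-of-message marker into bytes,
    # then push those bytes out over whatever I/O layer you chose.
    request = h11.Request(
        method="GET",
        target="/",
        headers=[("Host", "example.com"), ("Connection", "close")],
    )
    sock.sendall(conn.send(request))
    sock.sendall(conn.send(h11.EndOfMessage()))

    # Read raw bytes, feed them to the connection, and pull high-level events.
    while True:
        event = conn.next_event()
        if event is h11.NEED_DATA:
            conn.receive_data(sock.recv(4096))
        elif isinstance(event, h11.Response):
            print("status:", event.status_code)
        elif isinstance(event, h11.Data):
            print("body chunk of", len(event.data), "bytes")
        elif isinstance(event, h11.EndOfMessage):
            break

    sock.close()

The bring-your-own-I/O design is visible here: the two ``sendall`` calls and the single ``recv`` are the only places the network is touched, so the same event loop works unchanged on top of threads, asyncio, Trio, or any other I/O layer.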
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/RECORD b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..a63f6ccf5c25065163d0eaffae9999d542cf2fe2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/RECORD @@ -0,0 +1,52 @@ +h11-0.14.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +h11-0.14.0.dist-info/LICENSE.txt,sha256=N9tbuFkm2yikJ6JYZ_ELEjIAOuob5pzLhRE4rbjm82E,1124 +h11-0.14.0.dist-info/METADATA,sha256=B7pZ0m7WBXNs17vl6hUH9bJTL9s37DaGvY31w7jNxSg,8175 +h11-0.14.0.dist-info/RECORD,, +h11-0.14.0.dist-info/WHEEL,sha256=ewwEueio1C2XeHTvT17n8dZUJgOvyCWCt0WVNLClP9o,92 +h11-0.14.0.dist-info/top_level.txt,sha256=F7dC4jl3zeh8TGHEPaWJrMbeuoWbS379Gwdi-Yvdcis,4 +h11/__init__.py,sha256=iO1KzkSO42yZ6ffg-VMgbx_ZVTWGUY00nRYEWn-s3kY,1507 +h11/__pycache__/__init__.cpython-312.pyc,, +h11/__pycache__/_abnf.cpython-312.pyc,, +h11/__pycache__/_connection.cpython-312.pyc,, +h11/__pycache__/_events.cpython-312.pyc,, +h11/__pycache__/_headers.cpython-312.pyc,, +h11/__pycache__/_readers.cpython-312.pyc,, +h11/__pycache__/_receivebuffer.cpython-312.pyc,, +h11/__pycache__/_state.cpython-312.pyc,, +h11/__pycache__/_util.cpython-312.pyc,, +h11/__pycache__/_version.cpython-312.pyc,, +h11/__pycache__/_writers.cpython-312.pyc,, +h11/_abnf.py,sha256=ybixr0xsupnkA6GFAyMubuXF6Tc1lb_hF890NgCsfNc,4815 +h11/_connection.py,sha256=eS2sorMD0zKLCFiB9lW9W9F_Nzny2tjHa4e6s1ujr1c,26539 +h11/_events.py,sha256=LEfuvg1AbhHaVRwxCd0I-pFn9-ezUOaoL8o2Kvy1PBA,11816 +h11/_headers.py,sha256=RqB8cd8CN0blYPzcLe5qeCh-phv6D1U_CHj4hs67lgQ,10230 +h11/_readers.py,sha256=EbSed0jzwVUiD1nOPAeUcVE4Flf3wXkxfb8c06-OTBM,8383 +h11/_receivebuffer.py,sha256=xrspsdsNgWFxRfQcTXxR8RrdjRXXTK0Io5cQYWpJ1Ws,5252 +h11/_state.py,sha256=k1VL6SDbaPkSrZ-49ewCXDpuiUS69_46YhbWjuV1qEY,13300 +h11/_util.py,sha256=LWkkjXyJaFlAy6Lt39w73UStklFT5ovcvo0TkY7RYuk,4888 +h11/_version.py,sha256=LVyTdiZRzIIEv79UyOgbM5iUrJUllEzlCWaJEYBY1zc,686 +h11/_writers.py,sha256=oFKm6PtjeHfbj4RLX7VB7KDc1gIY53gXG3_HR9ltmTA,5081 +h11/py.typed,sha256=sow9soTwP9T_gEAQSVh7Gb8855h04Nwmhs2We-JRgZM,7 +h11/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +h11/tests/__pycache__/__init__.cpython-312.pyc,, +h11/tests/__pycache__/helpers.cpython-312.pyc,, +h11/tests/__pycache__/test_against_stdlib_http.cpython-312.pyc,, +h11/tests/__pycache__/test_connection.cpython-312.pyc,, +h11/tests/__pycache__/test_events.cpython-312.pyc,, +h11/tests/__pycache__/test_headers.cpython-312.pyc,, +h11/tests/__pycache__/test_helpers.cpython-312.pyc,, +h11/tests/__pycache__/test_io.cpython-312.pyc,, +h11/tests/__pycache__/test_receivebuffer.cpython-312.pyc,, +h11/tests/__pycache__/test_state.cpython-312.pyc,, +h11/tests/__pycache__/test_util.cpython-312.pyc,, +h11/tests/data/test-file,sha256=ZJ03Rqs98oJw29OHzJg7LlMzyGQaRAY0r3AqBeM2wVU,65 +h11/tests/helpers.py,sha256=a1EVG_p7xU4wRsa3tMPTRxuaKCmretok9sxXWvqfmQA,3355 +h11/tests/test_against_stdlib_http.py,sha256=cojCHgHXFQ8gWhNlEEwl3trmOpN-5uDukRoHnElqo3A,3995 +h11/tests/test_connection.py,sha256=ZbPLDPclKvjgjAhgk-WlCPBaf17c4XUIV2tpaW08jOI,38720 +h11/tests/test_events.py,sha256=LPVLbcV-NvPNK9fW3rraR6Bdpz1hAlsWubMtNaJ5gHg,4657 +h11/tests/test_headers.py,sha256=qd8T1Zenuz5GbD6wklSJ5G8VS7trrYgMV0jT-SMvqg8,5612 +h11/tests/test_helpers.py,sha256=kAo0CEM4LGqmyyP2ZFmhsyq3UFJqoFfAbzu3hbWreRM,794 
+h11/tests/test_io.py,sha256=uCZVnjarkRBkudfC1ij-KSCQ71XWJhnkgkgWWkKgYPQ,16386 +h11/tests/test_receivebuffer.py,sha256=3jGbeJM36Akqg_pAhPb7XzIn2NS6RhPg-Ryg8Eu6ytk,3454 +h11/tests/test_state.py,sha256=rqll9WqFsJPE0zSrtCn9LH659mPKsDeXZ-DwXwleuBQ,8928 +h11/tests/test_util.py,sha256=VO5L4nSFe4pgtSwKuv6u_6l0H7UeizF5WKuHTWreg70,2970 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/WHEEL b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..5bad85fdc1cd08553756d0fb2c7be8b5ad6af7fb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.37.0) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/top_level.txt b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..0d24def711344ec6f4da2108f7d5c9261eb35f8b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/h11-0.14.0.dist-info/top_level.txt @@ -0,0 +1 @@ +h11 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/INSTALLER b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/LICENSE b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..8727172ae058e56805bd8ed0f988b6788711dcfd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/LICENSE @@ -0,0 +1,13 @@ + Copyright 2016 Andrew Svetlov and aio-libs contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/METADATA b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..93f85177b97ec6be66b1ed74fc74cac756d1f72f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/METADATA @@ -0,0 +1,140 @@ +Metadata-Version: 2.1 +Name: multidict +Version: 6.1.0 +Summary: multidict implementation +Home-page: https://github.com/aio-libs/multidict +Author: Andrew Svetlov +Author-email: andrew.svetlov@gmail.com +License: Apache 2 +Project-URL: Chat: Matrix, https://matrix.to/#/#aio-libs:matrix.org +Project-URL: Chat: Matrix Space, https://matrix.to/#/#aio-libs-space:matrix.org +Project-URL: CI: GitHub, https://github.com/aio-libs/multidict/actions +Project-URL: Code of Conduct, https://github.com/aio-libs/.github/blob/master/CODE_OF_CONDUCT.md +Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/multidict +Project-URL: Docs: Changelog, https://multidict.aio-libs.org/en/latest/changes/ +Project-URL: Docs: RTD, https://multidict.aio-libs.org +Project-URL: GitHub: issues, https://github.com/aio-libs/multidict/issues +Project-URL: GitHub: repo, https://github.com/aio-libs/multidict +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Requires-Python: >=3.8 +Description-Content-Type: text/x-rst +License-File: LICENSE +Requires-Dist: typing-extensions >=4.1.0 ; python_version < "3.11" + +========= +multidict +========= + +.. image:: https://github.com/aio-libs/multidict/actions/workflows/ci-cd.yml/badge.svg + :target: https://github.com/aio-libs/multidict/actions + :alt: GitHub status for master branch + +.. image:: https://codecov.io/gh/aio-libs/multidict/branch/master/graph/badge.svg + :target: https://codecov.io/gh/aio-libs/multidict + :alt: Coverage metrics + +.. image:: https://img.shields.io/pypi/v/multidict.svg + :target: https://pypi.org/project/multidict + :alt: PyPI + +.. image:: https://readthedocs.org/projects/multidict/badge/?version=latest + :target: https://multidict.aio-libs.org + :alt: Read The Docs build status badge + +.. image:: https://img.shields.io/pypi/pyversions/multidict.svg + :target: https://pypi.org/project/multidict + :alt: Python versions + +.. image:: https://img.shields.io/matrix/aio-libs:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat + :target: https://matrix.to/#/%23aio-libs:matrix.org + :alt: Matrix Room — #aio-libs:matrix.org + +.. image:: https://img.shields.io/matrix/aio-libs-space:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs-space%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat + :target: https://matrix.to/#/%23aio-libs-space:matrix.org + :alt: Matrix Space — #aio-libs-space:matrix.org + +Multidict is dict-like collection of *key-value pairs* where key +might occur more than once in the container. 
+ +Introduction +------------ + +*HTTP Headers* and *URL query string* require specific data structure: +*multidict*. It behaves mostly like a regular ``dict`` but it may have +several *values* for the same *key* and *preserves insertion ordering*. + +The *key* is ``str`` (or ``istr`` for case-insensitive dictionaries). + +``multidict`` has four multidict classes: +``MultiDict``, ``MultiDictProxy``, ``CIMultiDict`` +and ``CIMultiDictProxy``. + +Immutable proxies (``MultiDictProxy`` and +``CIMultiDictProxy``) provide a dynamic view for the +proxied multidict, the view reflects underlying collection changes. They +implement the ``collections.abc.Mapping`` interface. + +Regular mutable (``MultiDict`` and ``CIMultiDict``) classes +implement ``collections.abc.MutableMapping`` and allows them to change +their own content. + + +*Case insensitive* (``CIMultiDict`` and +``CIMultiDictProxy``) assume the *keys* are case +insensitive, e.g.:: + + >>> dct = CIMultiDict(key='val') + >>> 'Key' in dct + True + >>> dct['Key'] + 'val' + +*Keys* should be ``str`` or ``istr`` instances. + +The library has optional C Extensions for speed. + + +License +------- + +Apache 2 + +Library Installation +-------------------- + +.. code-block:: bash + + $ pip install multidict + +The library is Python 3 only! + +PyPI contains binary wheels for Linux, Windows and MacOS. If you want to install +``multidict`` on another operating system (or *Alpine Linux* inside a Docker) the +tarball will be used to compile the library from source. It requires a C compiler and +Python headers to be installed. + +To skip the compilation, please use the `MULTIDICT_NO_EXTENSIONS` environment variable, +e.g.: + +.. code-block:: bash + + $ MULTIDICT_NO_EXTENSIONS=1 pip install multidict + +Please note, the pure Python (uncompiled) version is about 20-50 times slower depending on +the usage scenario!!! + + + +Changelog +--------- +See `RTD page `_. 
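As a quick illustration of the behaviour the multidict README describes above (several values per key, insertion order preserved, and case-insensitive variants for header-like data), here is a minimal usage sketch; the keys and values are arbitrary examples:

.. code-block:: python

    from multidict import CIMultiDict, MultiDict

    # A key may occur more than once; insertion order is preserved.
    d = MultiDict([("tag", "a"), ("tag", "b")])
    d.add("tag", "c")

    assert d["tag"] == "a"                      # item access returns the first value
    assert d.getall("tag") == ["a", "b", "c"]   # all values, in insertion order
    assert d.getone("tag") == "a"

    # Case-insensitive variant, convenient for HTTP headers.
    headers = CIMultiDict(Host="example.com")
    headers.add("Accept", "text/html")
    assert headers["host"] == "example.com"
    assert "ACCEPT" in headers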
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/RECORD b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..714c904bf84a5c555bc2db0239298b49b34cdcb9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/RECORD @@ -0,0 +1,19 @@ +multidict-6.1.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +multidict-6.1.0.dist-info/LICENSE,sha256=k9Ealo4vDzY3PECBH_bSDhc_WMPKtYhM1mF7v9eVSSo,611 +multidict-6.1.0.dist-info/METADATA,sha256=OnCx5DR4XPf64GIDK4XmcA2e7HLQ_784vMfEQy287kM,4979 +multidict-6.1.0.dist-info/RECORD,, +multidict-6.1.0.dist-info/WHEEL,sha256=3FRagTIevYnyede1Gym_XNKguJrd07UOyEdLNhxNq20,151 +multidict-6.1.0.dist-info/top_level.txt,sha256=-euDElkk5_qkmfIJ7WiqCab02ZlSFZWynejKg59qZQQ,10 +multidict/__init__.py,sha256=p60Ag5UVACSli1txazSi85foCmHN-cg3qZDCuWdOKng,928 +multidict/__init__.pyi,sha256=SbgC2ew1NvNXWlRKs9o0KhW4moozgMqgQ0OA4Re5JQQ,4840 +multidict/__pycache__/__init__.cpython-312.pyc,, +multidict/__pycache__/_abc.cpython-312.pyc,, +multidict/__pycache__/_compat.cpython-312.pyc,, +multidict/__pycache__/_multidict_base.cpython-312.pyc,, +multidict/__pycache__/_multidict_py.cpython-312.pyc,, +multidict/_abc.py,sha256=Zvnrn4SBkrv4QTD7-ZzqNcoxw0f8KStLMPzGvBuGT2w,1190 +multidict/_compat.py,sha256=uCNUpVHJSFOiKUJmRcz3SDqMpkb37C_csc29ijr8Evo,352 +multidict/_multidict.cpython-312-x86_64-linux-gnu.so,sha256=6BwP62oLns2chEgPfwAa8DseIoF0wOWBe81pHjnlqhs,418968 +multidict/_multidict_base.py,sha256=ZndtnZ5oc1sODKmXsv6F9kWvVNCda9xAEEFXkaPoFoA,3979 +multidict/_multidict_py.py,sha256=57h4sYrRIu7EjMX4YpHVIZVrV9-q1KCW3F6rao10D3U,15050 +multidict/py.typed,sha256=e9bmbH3UFxsabQrnNFPG9qxIXztwbcM6IKDYnvZwprY,15 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/WHEEL b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..c0c84dc2559cf7b954eb2b77e5a5b13eef99e6ac --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/WHEEL @@ -0,0 +1,6 @@ +Wheel-Version: 1.0 +Generator: setuptools (74.1.2) +Root-Is-Purelib: false +Tag: cp312-cp312-manylinux_2_17_x86_64 +Tag: cp312-cp312-manylinux2014_x86_64 + diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/top_level.txt b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..afcecdff08229f3faf1ecef41cf814c26c207f5c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/multidict-6.1.0.dist-info/top_level.txt @@ -0,0 +1 @@ +multidict diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_globals.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_globals.py new file mode 100644 index 0000000000000000000000000000000000000000..416a20f5e11b14b1da34e2bfb45c7961edc9097c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_globals.py @@ -0,0 +1,95 @@ +""" +Module defining global singleton classes. + +This module raises a RuntimeError if an attempt to reload it is made. In that +way the identities of the classes defined here are fixed and will remain so +even if numpy itself is reloaded. 
In particular, a function like the following +will still work correctly after numpy is reloaded:: + + def foo(arg=np._NoValue): + if arg is np._NoValue: + ... + +That was not the case when the singleton classes were defined in the numpy +``__init__.py`` file. See gh-7844 for a discussion of the reload problem that +motivated this module. + +""" +import enum + +from ._utils import set_module as _set_module + +__all__ = ['_NoValue', '_CopyMode'] + + +# Disallow reloading this module so as to preserve the identities of the +# classes defined here. +if '_is_loaded' in globals(): + raise RuntimeError('Reloading numpy._globals is not allowed') +_is_loaded = True + + +class _NoValueType: + """Special keyword value. + + The instance of this class may be used as the default value assigned to a + keyword if no other obvious default (e.g., `None`) is suitable, + + Common reasons for using this keyword are: + + - A new keyword is added to a function, and that function forwards its + inputs to another function or method which can be defined outside of + NumPy. For example, ``np.std(x)`` calls ``x.std``, so when a ``keepdims`` + keyword was added that could only be forwarded if the user explicitly + specified ``keepdims``; downstream array libraries may not have added + the same keyword, so adding ``x.std(..., keepdims=keepdims)`` + unconditionally could have broken previously working code. + - A keyword is being deprecated, and a deprecation warning must only be + emitted when the keyword is used. + + """ + __instance = None + def __new__(cls): + # ensure that only one instance exists + if not cls.__instance: + cls.__instance = super().__new__(cls) + return cls.__instance + + def __repr__(self): + return "" + + +_NoValue = _NoValueType() + + +@_set_module("numpy") +class _CopyMode(enum.Enum): + """ + An enumeration for the copy modes supported + by numpy.copy() and numpy.array(). The following three modes are supported, + + - ALWAYS: This means that a deep copy of the input + array will always be taken. + - IF_NEEDED: This means that a deep copy of the input + array will be taken only if necessary. + - NEVER: This means that the deep copy will never be taken. + If a copy cannot be avoided then a `ValueError` will be + raised. + + Note that the buffer-protocol could in theory do copies. NumPy currently + assumes an object exporting the buffer protocol will never do this. + """ + + ALWAYS = True + IF_NEEDED = False + NEVER = 2 + + def __bool__(self): + # For backwards compatibility + if self == _CopyMode.ALWAYS: + return True + + if self == _CopyMode.IF_NEEDED: + return False + + raise ValueError(f"{self} is neither True nor False.") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_pytesttester.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_pytesttester.py new file mode 100644 index 0000000000000000000000000000000000000000..1c38291ae3319a08bb665fe5c86dfa13e1655a4c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_pytesttester.py @@ -0,0 +1,207 @@ +""" +Pytest test running. + +This module implements the ``test()`` function for NumPy modules. The usual +boiler plate for doing that is to put the following in the module +``__init__.py`` file:: + + from numpy._pytesttester import PytestTester + test = PytestTester(__name__) + del PytestTester + + +Warnings filtering and other runtime settings should be dealt with in the +``pytest.ini`` file in the numpy repo root. 
The behavior of the test depends on +whether or not that file is found as follows: + +* ``pytest.ini`` is present (develop mode) + All warnings except those explicitly filtered out are raised as error. +* ``pytest.ini`` is absent (release mode) + DeprecationWarnings and PendingDeprecationWarnings are ignored, other + warnings are passed through. + +In practice, tests run from the numpy repo are run in develop mode. That +includes the standard ``python runtests.py`` invocation. + +This module is imported by every numpy subpackage, so lies at the top level to +simplify circular import issues. For the same reason, it contains no numpy +imports at module scope, instead importing numpy within function calls. +""" +import sys +import os + +__all__ = ['PytestTester'] + + +def _show_numpy_info(): + import numpy as np + + print("NumPy version %s" % np.__version__) + relaxed_strides = np.ones((10, 1), order="C").flags.f_contiguous + print("NumPy relaxed strides checking option:", relaxed_strides) + info = np.lib.utils._opt_info() + print("NumPy CPU features: ", (info if info else 'nothing enabled')) + + +class PytestTester: + """ + Pytest test runner. + + A test function is typically added to a package's __init__.py like so:: + + from numpy._pytesttester import PytestTester + test = PytestTester(__name__).test + del PytestTester + + Calling this test function finds and runs all tests associated with the + module and all its sub-modules. + + Attributes + ---------- + module_name : str + Full path to the package to test. + + Parameters + ---------- + module_name : module name + The name of the module to test. + + Notes + ----- + Unlike the previous ``nose``-based implementation, this class is not + publicly exposed as it performs some ``numpy``-specific warning + suppression. + + """ + def __init__(self, module_name): + self.module_name = module_name + + def __call__(self, label='fast', verbose=1, extra_argv=None, + doctests=False, coverage=False, durations=-1, tests=None): + """ + Run tests for module using pytest. + + Parameters + ---------- + label : {'fast', 'full'}, optional + Identifies the tests to run. When set to 'fast', tests decorated + with `pytest.mark.slow` are skipped, when 'full', the slow marker + is ignored. + verbose : int, optional + Verbosity value for test outputs, in the range 1-3. Default is 1. + extra_argv : list, optional + List with any extra arguments to pass to pytests. + doctests : bool, optional + .. note:: Not supported + coverage : bool, optional + If True, report coverage of NumPy code. Default is False. + Requires installation of (pip) pytest-cov. + durations : int, optional + If < 0, do nothing, If 0, report time of all tests, if > 0, + report the time of the slowest `timer` tests. Default is -1. + tests : test or list of tests + Tests to be executed with pytest '--pyargs' + + Returns + ------- + result : bool + Return True on success, false otherwise. + + Notes + ----- + Each NumPy module exposes `test` in its namespace to run all tests for + it. For example, to run all tests for numpy.lib: + + >>> np.lib.test() #doctest: +SKIP + + Examples + -------- + >>> result = np.lib.test() #doctest: +SKIP + ... + 1023 passed, 2 skipped, 6 deselected, 1 xfailed in 10.39 seconds + >>> result + True + + """ + import pytest + import warnings + + module = sys.modules[self.module_name] + module_path = os.path.abspath(module.__path__[0]) + + # setup the pytest arguments + pytest_args = ["-l"] + + # offset verbosity. The "-q" cancels a "-v". 
+ pytest_args += ["-q"] + + if sys.version_info < (3, 12): + with warnings.catch_warnings(): + warnings.simplefilter("always") + # Filter out distutils cpu warnings (could be localized to + # distutils tests). ASV has problems with top level import, + # so fetch module for suppression here. + from numpy.distutils import cpuinfo + + with warnings.catch_warnings(record=True): + # Ignore the warning from importing the array_api submodule. This + # warning is done on import, so it would break pytest collection, + # but importing it early here prevents the warning from being + # issued when it imported again. + import numpy.array_api + + # Filter out annoying import messages. Want these in both develop and + # release mode. + pytest_args += [ + "-W ignore:Not importing directory", + "-W ignore:numpy.dtype size changed", + "-W ignore:numpy.ufunc size changed", + "-W ignore::UserWarning:cpuinfo", + ] + + # When testing matrices, ignore their PendingDeprecationWarnings + pytest_args += [ + "-W ignore:the matrix subclass is not", + "-W ignore:Importing from numpy.matlib is", + ] + + if doctests: + pytest_args += ["--doctest-modules"] + + if extra_argv: + pytest_args += list(extra_argv) + + if verbose > 1: + pytest_args += ["-" + "v"*(verbose - 1)] + + if coverage: + pytest_args += ["--cov=" + module_path] + + if label == "fast": + # not importing at the top level to avoid circular import of module + from numpy.testing import IS_PYPY + if IS_PYPY: + pytest_args += ["-m", "not slow and not slow_pypy"] + else: + pytest_args += ["-m", "not slow"] + + elif label != "full": + pytest_args += ["-m", label] + + if durations >= 0: + pytest_args += ["--durations=%s" % durations] + + if tests is None: + tests = [self.module_name] + + pytest_args += ["--pyargs"] + list(tests) + + # run tests. + _show_numpy_info() + + try: + code = pytest.main(pytest_args) + except SystemExit as exc: + code = exc.code + + return code == 0 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_pytesttester.pyi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_pytesttester.pyi new file mode 100644 index 0000000000000000000000000000000000000000..67ac87b33de164c710a25110d45545e24a06d42e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_pytesttester.pyi @@ -0,0 +1,18 @@ +from collections.abc import Iterable +from typing import Literal as L + +__all__: list[str] + +class PytestTester: + module_name: str + def __init__(self, module_name: str) -> None: ... + def __call__( + self, + label: L["fast", "full"] = ..., + verbose: int = ..., + extra_argv: None | Iterable[str] = ..., + doctests: L[False] = ..., + coverage: bool = ..., + durations: int = ..., + tests: None | Iterable[str] = ..., + ) -> bool: ... diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/exceptions.pyi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/exceptions.pyi new file mode 100644 index 0000000000000000000000000000000000000000..c76a0946b97b088c9f0c431eb559b5a3c86a4f6b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/exceptions.pyi @@ -0,0 +1,18 @@ +from typing import overload + +__all__: list[str] + +class ComplexWarning(RuntimeWarning): ... +class ModuleDeprecationWarning(DeprecationWarning): ... +class VisibleDeprecationWarning(UserWarning): ... +class TooHardError(RuntimeError): ... +class DTypePromotionError(TypeError): ... 
+ +class AxisError(ValueError, IndexError): + axis: None | int + ndim: None | int + @overload + def __init__(self, axis: str, ndim: None = ..., msg_prefix: None = ...) -> None: ... + @overload + def __init__(self, axis: int, ndim: int, msg_prefix: None | str = ...) -> None: ... + def __str__(self) -> str: ... diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/matlib.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/matlib.py new file mode 100644 index 0000000000000000000000000000000000000000..e929fd9b1885f208afb6301f19cc21511adc098b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/matlib.py @@ -0,0 +1,378 @@ +import warnings + +# 2018-05-29, PendingDeprecationWarning added to matrix.__new__ +# 2020-01-23, numpy 1.19.0 PendingDeprecatonWarning +warnings.warn("Importing from numpy.matlib is deprecated since 1.19.0. " + "The matrix subclass is not the recommended way to represent " + "matrices or deal with linear algebra (see " + "https://docs.scipy.org/doc/numpy/user/numpy-for-matlab-users.html). " + "Please adjust your code to use regular ndarray. ", + PendingDeprecationWarning, stacklevel=2) + +import numpy as np +from numpy.matrixlib.defmatrix import matrix, asmatrix +# Matlib.py contains all functions in the numpy namespace with a few +# replacements. See doc/source/reference/routines.matlib.rst for details. +# Need * as we're copying the numpy namespace. +from numpy import * # noqa: F403 + +__version__ = np.__version__ + +__all__ = np.__all__[:] # copy numpy namespace +__all__ += ['rand', 'randn', 'repmat'] + +def empty(shape, dtype=None, order='C'): + """Return a new matrix of given shape and type, without initializing entries. + + Parameters + ---------- + shape : int or tuple of int + Shape of the empty matrix. + dtype : data-type, optional + Desired output data-type. + order : {'C', 'F'}, optional + Whether to store multi-dimensional data in row-major + (C-style) or column-major (Fortran-style) order in + memory. + + See Also + -------- + empty_like, zeros + + Notes + ----- + `empty`, unlike `zeros`, does not set the matrix values to zero, + and may therefore be marginally faster. On the other hand, it requires + the user to manually set all the values in the array, and should be + used with caution. + + Examples + -------- + >>> import numpy.matlib + >>> np.matlib.empty((2, 2)) # filled with random data + matrix([[ 6.76425276e-320, 9.79033856e-307], # random + [ 7.39337286e-309, 3.22135945e-309]]) + >>> np.matlib.empty((2, 2), dtype=int) + matrix([[ 6600475, 0], # random + [ 6586976, 22740995]]) + + """ + return ndarray.__new__(matrix, shape, dtype, order=order) + +def ones(shape, dtype=None, order='C'): + """ + Matrix of ones. + + Return a matrix of given shape and type, filled with ones. + + Parameters + ---------- + shape : {sequence of ints, int} + Shape of the matrix + dtype : data-type, optional + The desired data-type for the matrix, default is np.float64. + order : {'C', 'F'}, optional + Whether to store matrix in C- or Fortran-contiguous order, + default is 'C'. + + Returns + ------- + out : matrix + Matrix of ones of given shape, dtype, and order. + + See Also + -------- + ones : Array of ones. + matlib.zeros : Zero matrix. + + Notes + ----- + If `shape` has length one i.e. ``(N,)``, or is a scalar ``N``, + `out` becomes a single row matrix of shape ``(1,N)``. 
+ + Examples + -------- + >>> np.matlib.ones((2,3)) + matrix([[1., 1., 1.], + [1., 1., 1.]]) + + >>> np.matlib.ones(2) + matrix([[1., 1.]]) + + """ + a = ndarray.__new__(matrix, shape, dtype, order=order) + a.fill(1) + return a + +def zeros(shape, dtype=None, order='C'): + """ + Return a matrix of given shape and type, filled with zeros. + + Parameters + ---------- + shape : int or sequence of ints + Shape of the matrix + dtype : data-type, optional + The desired data-type for the matrix, default is float. + order : {'C', 'F'}, optional + Whether to store the result in C- or Fortran-contiguous order, + default is 'C'. + + Returns + ------- + out : matrix + Zero matrix of given shape, dtype, and order. + + See Also + -------- + numpy.zeros : Equivalent array function. + matlib.ones : Return a matrix of ones. + + Notes + ----- + If `shape` has length one i.e. ``(N,)``, or is a scalar ``N``, + `out` becomes a single row matrix of shape ``(1,N)``. + + Examples + -------- + >>> import numpy.matlib + >>> np.matlib.zeros((2, 3)) + matrix([[0., 0., 0.], + [0., 0., 0.]]) + + >>> np.matlib.zeros(2) + matrix([[0., 0.]]) + + """ + a = ndarray.__new__(matrix, shape, dtype, order=order) + a.fill(0) + return a + +def identity(n,dtype=None): + """ + Returns the square identity matrix of given size. + + Parameters + ---------- + n : int + Size of the returned identity matrix. + dtype : data-type, optional + Data-type of the output. Defaults to ``float``. + + Returns + ------- + out : matrix + `n` x `n` matrix with its main diagonal set to one, + and all other elements zero. + + See Also + -------- + numpy.identity : Equivalent array function. + matlib.eye : More general matrix identity function. + + Examples + -------- + >>> import numpy.matlib + >>> np.matlib.identity(3, dtype=int) + matrix([[1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + """ + a = array([1]+n*[0], dtype=dtype) + b = empty((n, n), dtype=dtype) + b.flat = a + return b + +def eye(n,M=None, k=0, dtype=float, order='C'): + """ + Return a matrix with ones on the diagonal and zeros elsewhere. + + Parameters + ---------- + n : int + Number of rows in the output. + M : int, optional + Number of columns in the output, defaults to `n`. + k : int, optional + Index of the diagonal: 0 refers to the main diagonal, + a positive value refers to an upper diagonal, + and a negative value to a lower diagonal. + dtype : dtype, optional + Data-type of the returned matrix. + order : {'C', 'F'}, optional + Whether the output should be stored in row-major (C-style) or + column-major (Fortran-style) order in memory. + + .. versionadded:: 1.14.0 + + Returns + ------- + I : matrix + A `n` x `M` matrix where all elements are equal to zero, + except for the `k`-th diagonal, whose values are equal to one. + + See Also + -------- + numpy.eye : Equivalent array function. + identity : Square identity matrix. + + Examples + -------- + >>> import numpy.matlib + >>> np.matlib.eye(3, k=1, dtype=float) + matrix([[0., 1., 0.], + [0., 0., 1.], + [0., 0., 0.]]) + + """ + return asmatrix(np.eye(n, M=M, k=k, dtype=dtype, order=order)) + +def rand(*args): + """ + Return a matrix of random values with given shape. + + Create a matrix of the given shape and propagate it with + random samples from a uniform distribution over ``[0, 1)``. + + Parameters + ---------- + \\*args : Arguments + Shape of the output. + If given as N integers, each integer specifies the size of one + dimension. + If given as a tuple, this tuple gives the complete shape. 
+ + Returns + ------- + out : ndarray + The matrix of random values with shape given by `\\*args`. + + See Also + -------- + randn, numpy.random.RandomState.rand + + Examples + -------- + >>> np.random.seed(123) + >>> import numpy.matlib + >>> np.matlib.rand(2, 3) + matrix([[0.69646919, 0.28613933, 0.22685145], + [0.55131477, 0.71946897, 0.42310646]]) + >>> np.matlib.rand((2, 3)) + matrix([[0.9807642 , 0.68482974, 0.4809319 ], + [0.39211752, 0.34317802, 0.72904971]]) + + If the first argument is a tuple, other arguments are ignored: + + >>> np.matlib.rand((2, 3), 4) + matrix([[0.43857224, 0.0596779 , 0.39804426], + [0.73799541, 0.18249173, 0.17545176]]) + + """ + if isinstance(args[0], tuple): + args = args[0] + return asmatrix(np.random.rand(*args)) + +def randn(*args): + """ + Return a random matrix with data from the "standard normal" distribution. + + `randn` generates a matrix filled with random floats sampled from a + univariate "normal" (Gaussian) distribution of mean 0 and variance 1. + + Parameters + ---------- + \\*args : Arguments + Shape of the output. + If given as N integers, each integer specifies the size of one + dimension. If given as a tuple, this tuple gives the complete shape. + + Returns + ------- + Z : matrix of floats + A matrix of floating-point samples drawn from the standard normal + distribution. + + See Also + -------- + rand, numpy.random.RandomState.randn + + Notes + ----- + For random samples from the normal distribution with mean ``mu`` and + standard deviation ``sigma``, use:: + + sigma * np.matlib.randn(...) + mu + + Examples + -------- + >>> np.random.seed(123) + >>> import numpy.matlib + >>> np.matlib.randn(1) + matrix([[-1.0856306]]) + >>> np.matlib.randn(1, 2, 3) + matrix([[ 0.99734545, 0.2829785 , -1.50629471], + [-0.57860025, 1.65143654, -2.42667924]]) + + Two-by-four matrix of samples from the normal distribution with + mean 3 and standard deviation 2.5: + + >>> 2.5 * np.matlib.randn((2, 4)) + 3 + matrix([[1.92771843, 6.16484065, 0.83314899, 1.30278462], + [2.76322758, 6.72847407, 1.40274501, 1.8900451 ]]) + + """ + if isinstance(args[0], tuple): + args = args[0] + return asmatrix(np.random.randn(*args)) + +def repmat(a, m, n): + """ + Repeat a 0-D to 2-D array or matrix MxN times. + + Parameters + ---------- + a : array_like + The array or matrix to be repeated. + m, n : int + The number of times `a` is repeated along the first and second axes. + + Returns + ------- + out : ndarray + The result of repeating `a`. 
+ + Examples + -------- + >>> import numpy.matlib + >>> a0 = np.array(1) + >>> np.matlib.repmat(a0, 2, 3) + array([[1, 1, 1], + [1, 1, 1]]) + + >>> a1 = np.arange(4) + >>> np.matlib.repmat(a1, 2, 2) + array([[0, 1, 2, 3, 0, 1, 2, 3], + [0, 1, 2, 3, 0, 1, 2, 3]]) + + >>> a2 = np.asmatrix(np.arange(6).reshape(2, 3)) + >>> np.matlib.repmat(a2, 2, 3) + matrix([[0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5], + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5]]) + + """ + a = asanyarray(a) + ndim = a.ndim + if ndim == 0: + origrows, origcols = (1, 1) + elif ndim == 1: + origrows, origcols = (1, a.shape[0]) + else: + origrows, origcols = a.shape + rows = origrows * m + cols = origcols * n + c = a.reshape(1, a.size).repeat(m, 0).reshape(rows, origcols).repeat(n, 0) + return c.reshape(rows, cols) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/py.typed b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/INSTALLER b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/License.txt b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/License.txt new file mode 100644 index 0000000000000000000000000000000000000000..b491c70e0aef319022ded661e111ddbd45b8a17f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/License.txt @@ -0,0 +1,1568 @@ +End User License Agreement +-------------------------- + + +Preface +------- + +The Software License Agreement in Chapter 1 and the Supplement +in Chapter 2 contain license terms and conditions that govern +the use of NVIDIA software. By accepting this agreement, you +agree to comply with all the terms and conditions applicable +to the product(s) included herein. + + +NVIDIA Driver + + +Description + +This package contains the operating system driver and +fundamental system software components for NVIDIA GPUs. + + +NVIDIA CUDA Toolkit + + +Description + +The NVIDIA CUDA Toolkit provides command-line and graphical +tools for building, debugging and optimizing the performance +of applications accelerated by NVIDIA GPUs, runtime and math +libraries, and documentation including programming guides, +user manuals, and API references. + + +Default Install Location of CUDA Toolkit + +Windows platform: + +%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v#.# + +Linux platform: + +/usr/local/cuda-#.# + +Mac platform: + +/Developer/NVIDIA/CUDA-#.# + + +NVIDIA CUDA Samples + + +Description + +This package includes over 100+ CUDA examples that demonstrate +various CUDA programming principles, and efficient CUDA +implementation of algorithms in specific application domains. 
+ + +Default Install Location of CUDA Samples + +Windows platform: + +%ProgramData%\NVIDIA Corporation\CUDA Samples\v#.# + +Linux platform: + +/usr/local/cuda-#.#/samples + +and + +$HOME/NVIDIA_CUDA-#.#_Samples + +Mac platform: + +/Developer/NVIDIA/CUDA-#.#/samples + + +NVIDIA Nsight Visual Studio Edition (Windows only) + + +Description + +NVIDIA Nsight Development Platform, Visual Studio Edition is a +development environment integrated into Microsoft Visual +Studio that provides tools for debugging, profiling, analyzing +and optimizing your GPU computing and graphics applications. + + +Default Install Location of Nsight Visual Studio Edition + +Windows platform: + +%ProgramFiles(x86)%\NVIDIA Corporation\Nsight Visual Studio Edition #.# + + +1. License Agreement for NVIDIA Software Development Kits +--------------------------------------------------------- + + +Release Date: July 26, 2018 +--------------------------- + + +Important NoticeRead before downloading, installing, +copying or using the licensed software: +------------------------------------------------------- + +This license agreement, including exhibits attached +("Agreement”) is a legal agreement between you and NVIDIA +Corporation ("NVIDIA") and governs your use of a NVIDIA +software development kit (“SDK”). + +Each SDK has its own set of software and materials, but here +is a description of the types of items that may be included in +a SDK: source code, header files, APIs, data sets and assets +(examples include images, textures, models, scenes, videos, +native API input/output files), binary software, sample code, +libraries, utility programs, programming code and +documentation. + +This Agreement can be accepted only by an adult of legal age +of majority in the country in which the SDK is used. + +If you are entering into this Agreement on behalf of a company +or other legal entity, you represent that you have the legal +authority to bind the entity to this Agreement, in which case +“you” will mean the entity you represent. + +If you don’t have the required age or authority to accept +this Agreement, or if you don’t accept all the terms and +conditions of this Agreement, do not download, install or use +the SDK. + +You agree to use the SDK only for purposes that are permitted +by (a) this Agreement, and (b) any applicable law, regulation +or generally accepted practices or guidelines in the relevant +jurisdictions. + + +1.1. License + + +1.1.1. License Grant + +Subject to the terms of this Agreement, NVIDIA hereby grants +you a non-exclusive, non-transferable license, without the +right to sublicense (except as expressly provided in this +Agreement) to: + + 1. Install and use the SDK, + + 2. Modify and create derivative works of sample source code + delivered in the SDK, and + + 3. Distribute those portions of the SDK that are identified + in this Agreement as distributable, as incorporated in + object code format into a software application that meets + the distribution requirements indicated in this Agreement. + + +1.1.2. Distribution Requirements + +These are the distribution requirements for you to exercise +the distribution grant: + + 1. Your application must have material additional + functionality, beyond the included portions of the SDK. + + 2. The distributable portions of the SDK shall only be + accessed by your application. + + 3. 
The following notice shall be included in modifications + and derivative works of sample source code distributed: + “This software contains source code provided by NVIDIA + Corporation.” + + 4. Unless a developer tool is identified in this Agreement + as distributable, it is delivered for your internal use + only. + + 5. The terms under which you distribute your application + must be consistent with the terms of this Agreement, + including (without limitation) terms relating to the + license grant and license restrictions and protection of + NVIDIA’s intellectual property rights. Additionally, you + agree that you will protect the privacy, security and + legal rights of your application users. + + 6. You agree to notify NVIDIA in writing of any known or + suspected distribution or use of the SDK not in compliance + with the requirements of this Agreement, and to enforce + the terms of your agreements with respect to distributed + SDK. + + +1.1.3. Authorized Users + +You may allow employees and contractors of your entity or of +your subsidiary(ies) to access and use the SDK from your +secure network to perform work on your behalf. + +If you are an academic institution you may allow users +enrolled or employed by the academic institution to access and +use the SDK from your secure network. + +You are responsible for the compliance with the terms of this +Agreement by your authorized users. If you become aware that +your authorized users didn’t follow the terms of this +Agreement, you agree to take reasonable steps to resolve the +non-compliance and prevent new occurrences. + + +1.1.4. Pre-Release SDK + +The SDK versions identified as alpha, beta, preview or +otherwise as pre-release, may not be fully functional, may +contain errors or design flaws, and may have reduced or +different security, privacy, accessibility, availability, and +reliability standards relative to commercial versions of +NVIDIA software and materials. Use of a pre-release SDK may +result in unexpected results, loss of data, project delays or +other unpredictable damage or loss. + +You may use a pre-release SDK at your own risk, understanding +that pre-release SDKs are not intended for use in production +or business-critical systems. + +NVIDIA may choose not to make available a commercial version +of any pre-release SDK. NVIDIA may also choose to abandon +development and terminate the availability of a pre-release +SDK at any time without liability. + + +1.1.5. Updates + +NVIDIA may, at its option, make available patches, workarounds +or other updates to this SDK. Unless the updates are provided +with their separate governing terms, they are deemed part of +the SDK licensed to you as provided in this Agreement. You +agree that the form and content of the SDK that NVIDIA +provides may change without prior notice to you. While NVIDIA +generally maintains compatibility between versions, NVIDIA may +in some cases make changes that introduce incompatibilities in +future versions of the SDK. + + +1.1.6. Third Party Licenses + +The SDK may come bundled with, or otherwise include or be +distributed with, third party software licensed by a NVIDIA +supplier and/or open source software provided under an open +source license. Use of third party software is subject to the +third-party license terms, or in the absence of third party +terms, the terms of this Agreement. Copyright to third party +software is held by the copyright holders indicated in the +third-party software or license. + + +1.1.7. 
Reservation of Rights + +NVIDIA reserves all rights, title, and interest in and to the +SDK, not expressly granted to you under this Agreement. + + +1.2. Limitations + +The following license limitations apply to your use of the +SDK: + + 1. You may not reverse engineer, decompile or disassemble, + or remove copyright or other proprietary notices from any + portion of the SDK or copies of the SDK. + + 2. Except as expressly provided in this Agreement, you may + not copy, sell, rent, sublicense, transfer, distribute, + modify, or create derivative works of any portion of the + SDK. For clarity, you may not distribute or sublicense the + SDK as a stand-alone product. + + 3. Unless you have an agreement with NVIDIA for this + purpose, you may not indicate that an application created + with the SDK is sponsored or endorsed by NVIDIA. + + 4. You may not bypass, disable, or circumvent any + encryption, security, digital rights management or + authentication mechanism in the SDK. + + 5. You may not use the SDK in any manner that would cause it + to become subject to an open source software license. As + examples, licenses that require as a condition of use, + modification, and/or distribution that the SDK be: + + a. Disclosed or distributed in source code form; + + b. Licensed for the purpose of making derivative works; + or + + c. Redistributable at no charge. + + 6. Unless you have an agreement with NVIDIA for this + purpose, you may not use the SDK with any system or + application where the use or failure of the system or + application can reasonably be expected to threaten or + result in personal injury, death, or catastrophic loss. + Examples include use in avionics, navigation, military, + medical, life support or other life critical applications. + NVIDIA does not design, test or manufacture the SDK for + these critical uses and NVIDIA shall not be liable to you + or any third party, in whole or in part, for any claims or + damages arising from such uses. + + 7. You agree to defend, indemnify and hold harmless NVIDIA + and its affiliates, and their respective employees, + contractors, agents, officers and directors, from and + against any and all claims, damages, obligations, losses, + liabilities, costs or debt, fines, restitutions and + expenses (including but not limited to attorney’s fees + and costs incident to establishing the right of + indemnification) arising out of or related to your use of + the SDK outside of the scope of this Agreement, or not in + compliance with its terms. + + +1.3. Ownership + + 1. NVIDIA or its licensors hold all rights, title and + interest in and to the SDK and its modifications and + derivative works, including their respective intellectual + property rights, subject to your rights described in this + section. This SDK may include software and materials from + NVIDIA’s licensors, and these licensors are intended + third party beneficiaries that may enforce this Agreement + with respect to their intellectual property rights. + + 2. You hold all rights, title and interest in and to your + applications and your derivative works of the sample + source code delivered in the SDK, including their + respective intellectual property rights, subject to + NVIDIA’s rights described in this section. + + 3. You may, but don’t have to, provide to NVIDIA + suggestions, feature requests or other feedback regarding + the SDK, including possible enhancements or modifications + to the SDK. 
For any feedback that you voluntarily provide, + you hereby grant NVIDIA and its affiliates a perpetual, + non-exclusive, worldwide, irrevocable license to use, + reproduce, modify, license, sublicense (through multiple + tiers of sublicensees), and distribute (through multiple + tiers of distributors) it without the payment of any + royalties or fees to you. NVIDIA will use feedback at its + choice. NVIDIA is constantly looking for ways to improve + its products, so you may send feedback to NVIDIA through + the developer portal at https://developer.nvidia.com. + + +1.4. No Warranties + +THE SDK IS PROVIDED BY NVIDIA “AS IS” AND “WITH ALL +FAULTS.” TO THE MAXIMUM EXTENT PERMITTED BY LAW, NVIDIA AND +ITS AFFILIATES EXPRESSLY DISCLAIM ALL WARRANTIES OF ANY KIND +OR NATURE, WHETHER EXPRESS, IMPLIED OR STATUTORY, INCLUDING, +BUT NOT LIMITED TO, ANY WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE, TITLE, NON-INFRINGEMENT, OR THE +ABSENCE OF ANY DEFECTS THEREIN, WHETHER LATENT OR PATENT. NO +WARRANTY IS MADE ON THE BASIS OF TRADE USAGE, COURSE OF +DEALING OR COURSE OF TRADE. + + +1.5. Limitation of Liability + +TO THE MAXIMUM EXTENT PERMITTED BY LAW, NVIDIA AND ITS +AFFILIATES SHALL NOT BE LIABLE FOR ANY SPECIAL, INCIDENTAL, +PUNITIVE OR CONSEQUENTIAL DAMAGES, OR ANY LOST PROFITS, LOSS +OF USE, LOSS OF DATA OR LOSS OF GOODWILL, OR THE COSTS OF +PROCURING SUBSTITUTE PRODUCTS, ARISING OUT OF OR IN CONNECTION +WITH THIS AGREEMENT OR THE USE OR PERFORMANCE OF THE SDK, +WHETHER SUCH LIABILITY ARISES FROM ANY CLAIM BASED UPON BREACH +OF CONTRACT, BREACH OF WARRANTY, TORT (INCLUDING NEGLIGENCE), +PRODUCT LIABILITY OR ANY OTHER CAUSE OF ACTION OR THEORY OF +LIABILITY. IN NO EVENT WILL NVIDIA’S AND ITS AFFILIATES +TOTAL CUMULATIVE LIABILITY UNDER OR ARISING OUT OF THIS +AGREEMENT EXCEED US$10.00. THE NATURE OF THE LIABILITY OR THE +NUMBER OF CLAIMS OR SUITS SHALL NOT ENLARGE OR EXTEND THIS +LIMIT. + +These exclusions and limitations of liability shall apply +regardless if NVIDIA or its affiliates have been advised of +the possibility of such damages, and regardless of whether a +remedy fails its essential purpose. These exclusions and +limitations of liability form an essential basis of the +bargain between the parties, and, absent any of these +exclusions or limitations of liability, the provisions of this +Agreement, including, without limitation, the economic terms, +would be substantially different. + + +1.6. Termination + + 1. This Agreement will continue to apply until terminated by + either you or NVIDIA as described below. + + 2. If you want to terminate this Agreement, you may do so by + stopping to use the SDK. + + 3. NVIDIA may, at any time, terminate this Agreement if: + + a. (i) you fail to comply with any term of this + Agreement and the non-compliance is not fixed within + thirty (30) days following notice from NVIDIA (or + immediately if you violate NVIDIA’s intellectual + property rights); + + b. (ii) you commence or participate in any legal + proceeding against NVIDIA with respect to the SDK; or + + c. (iii) NVIDIA decides to no longer provide the SDK in + a country or, in NVIDIA’s sole discretion, the + continued use of it is no longer commercially viable. + + 4. Upon any termination of this Agreement, you agree to + promptly discontinue use of the SDK and destroy all copies + in your possession or control. Your prior distributions in + accordance with this Agreement are not affected by the + termination of this Agreement. 
Upon written request, you + will certify in writing that you have complied with your + commitments under this section. Upon any termination of + this Agreement all provisions survive except for the + license grant provisions. + + +1.7. General + +If you wish to assign this Agreement or your rights and +obligations, including by merger, consolidation, dissolution +or operation of law, contact NVIDIA to ask for permission. Any +attempted assignment not approved by NVIDIA in writing shall +be void and of no effect. NVIDIA may assign, delegate or +transfer this Agreement and its rights and obligations, and if +to a non-affiliate you will be notified. + +You agree to cooperate with NVIDIA and provide reasonably +requested information to verify your compliance with this +Agreement. + +This Agreement will be governed in all respects by the laws of +the United States and of the State of Delaware as those laws +are applied to contracts entered into and performed entirely +within Delaware by Delaware residents, without regard to the +conflicts of laws principles. The United Nations Convention on +Contracts for the International Sale of Goods is specifically +disclaimed. You agree to all terms of this Agreement in the +English language. + +The state or federal courts residing in Santa Clara County, +California shall have exclusive jurisdiction over any dispute +or claim arising out of this Agreement. Notwithstanding this, +you agree that NVIDIA shall still be allowed to apply for +injunctive remedies or an equivalent type of urgent legal +relief in any jurisdiction. + +If any court of competent jurisdiction determines that any +provision of this Agreement is illegal, invalid or +unenforceable, such provision will be construed as limited to +the extent necessary to be consistent with and fully +enforceable under the law and the remaining provisions will +remain in full force and effect. Unless otherwise specified, +remedies are cumulative. + +Each party acknowledges and agrees that the other is an +independent contractor in the performance of this Agreement. + +The SDK has been developed entirely at private expense and is +“commercial items” consisting of “commercial computer +software” and “commercial computer software +documentation” provided with RESTRICTED RIGHTS. Use, +duplication or disclosure by the U.S. Government or a U.S. +Government subcontractor is subject to the restrictions in +this Agreement pursuant to DFARS 227.7202-3(a) or as set forth +in subparagraphs (c)(1) and (2) of the Commercial Computer +Software - Restricted Rights clause at FAR 52.227-19, as +applicable. Contractor/manufacturer is NVIDIA, 2788 San Tomas +Expressway, Santa Clara, CA 95051. + +The SDK is subject to United States export laws and +regulations. You agree that you will not ship, transfer or +export the SDK into any country, or use the SDK in any manner, +prohibited by the United States Bureau of Industry and +Security or economic sanctions regulations administered by the +U.S. Department of Treasury’s Office of Foreign Assets +Control (OFAC), or any applicable export laws, restrictions or +regulations. These laws include restrictions on destinations, +end users and end use. By accepting this Agreement, you +confirm that you are not a resident or citizen of any country +currently embargoed by the U.S. and that you are not otherwise +prohibited from receiving the SDK. + +Any notice delivered by NVIDIA to you under this Agreement +will be delivered via mail, email or fax. 
You agree that any +notices that NVIDIA sends you electronically will satisfy any +legal communication requirements. Please direct your legal +notices or other correspondence to NVIDIA Corporation, 2788 +San Tomas Expressway, Santa Clara, California 95051, United +States of America, Attention: Legal Department. + +This Agreement and any exhibits incorporated into this +Agreement constitute the entire agreement of the parties with +respect to the subject matter of this Agreement and supersede +all prior negotiations or documentation exchanged between the +parties relating to this SDK license. Any additional and/or +conflicting terms on documents issued by you are null, void, +and invalid. Any amendment or waiver under this Agreement +shall be in writing and signed by representatives of both +parties. + + +2. CUDA Toolkit Supplement to Software License Agreement for +NVIDIA Software Development Kits +------------------------------------------------------------ + + +Release date: August 16, 2018 +----------------------------- + +The terms in this supplement govern your use of the NVIDIA +CUDA Toolkit SDK under the terms of your license agreement +(“Agreement”) as modified by this supplement. Capitalized +terms used but not defined below have the meaning assigned to +them in the Agreement. + +This supplement is an exhibit to the Agreement and is +incorporated as an integral part of the Agreement. In the +event of conflict between the terms in this supplement and the +terms in the Agreement, the terms in this supplement govern. + + +2.1. License Scope + +The SDK is licensed for you to develop applications only for +use in systems with NVIDIA GPUs. + + +2.2. Distribution + +The portions of the SDK that are distributable under the +Agreement are listed in Attachment A. + + +2.3. Operating Systems + +Those portions of the SDK designed exclusively for use on the +Linux or FreeBSD operating systems, or other operating systems +derived from the source code to these operating systems, may +be copied and redistributed for use in accordance with this +Agreement, provided that the object code files are not +modified in any way (except for unzipping of compressed +files). + + +2.4. Audio and Video Encoders and Decoders + +You acknowledge and agree that it is your sole responsibility +to obtain any additional third-party licenses required to +make, have made, use, have used, sell, import, and offer for +sale your products or services that include or incorporate any +third-party software and content relating to audio and/or +video encoders and decoders from, including but not limited +to, Microsoft, Thomson, Fraunhofer IIS, Sisvel S.p.A., +MPEG-LA, and Coding Technologies. NVIDIA does not grant to you +under this Agreement any necessary patent or other rights with +respect to any audio and/or video encoders and decoders. + + +2.5. Licensing + +If the distribution terms in this Agreement are not suitable +for your organization, or for any questions regarding this +Agreement, please contact NVIDIA at +nvidia-compute-license-questions@nvidia.com. + + +2.6. 
Attachment A + +The following portions of the SDK are distributable under the +Agreement: + +Component + +CUDA Runtime + +Windows + +cudart.dll, cudart_static.lib, cudadevrt.lib + +Mac OSX + +libcudart.dylib, libcudart_static.a, libcudadevrt.a + +Linux + +libcudart.so, libcudart_static.a, libcudadevrt.a + +Android + +libcudart.so, libcudart_static.a, libcudadevrt.a + +Component + +CUDA FFT Library + +Windows + +cufft.dll, cufftw.dll, cufft.lib, cufftw.lib + +Mac OSX + +libcufft.dylib, libcufft_static.a, libcufftw.dylib, +libcufftw_static.a + +Linux + +libcufft.so, libcufft_static.a, libcufftw.so, +libcufftw_static.a + +Android + +libcufft.so, libcufft_static.a, libcufftw.so, +libcufftw_static.a + +Component + +CUDA BLAS Library + +Windows + +cublas.dll, cublasLt.dll + +Mac OSX + +libcublas.dylib, libcublasLt.dylib, libcublas_static.a, +libcublasLt_static.a + +Linux + +libcublas.so, libcublasLt.so, libcublas_static.a, +libcublasLt_static.a + +Android + +libcublas.so, libcublasLt.so, libcublas_static.a, +libcublasLt_static.a + +Component + +NVIDIA "Drop-in" BLAS Library + +Windows + +nvblas.dll + +Mac OSX + +libnvblas.dylib + +Linux + +libnvblas.so + +Component + +CUDA Sparse Matrix Library + +Windows + +cusparse.dll, cusparse.lib + +Mac OSX + +libcusparse.dylib, libcusparse_static.a + +Linux + +libcusparse.so, libcusparse_static.a + +Android + +libcusparse.so, libcusparse_static.a + +Component + +CUDA Linear Solver Library + +Windows + +cusolver.dll, cusolver.lib + +Mac OSX + +libcusolver.dylib, libcusolver_static.a + +Linux + +libcusolver.so, libcusolver_static.a + +Android + +libcusolver.so, libcusolver_static.a + +Component + +CUDA Random Number Generation Library + +Windows + +curand.dll, curand.lib + +Mac OSX + +libcurand.dylib, libcurand_static.a + +Linux + +libcurand.so, libcurand_static.a + +Android + +libcurand.so, libcurand_static.a + +Component + +CUDA Accelerated Graph Library + +Component + +NVIDIA Performance Primitives Library + +Windows + +nppc.dll, nppc.lib, nppial.dll, nppial.lib, nppicc.dll, +nppicc.lib, nppicom.dll, nppicom.lib, nppidei.dll, +nppidei.lib, nppif.dll, nppif.lib, nppig.dll, nppig.lib, +nppim.dll, nppim.lib, nppist.dll, nppist.lib, nppisu.dll, +nppisu.lib, nppitc.dll, nppitc.lib, npps.dll, npps.lib + +Mac OSX + +libnppc.dylib, libnppc_static.a, libnppial.dylib, +libnppial_static.a, libnppicc.dylib, libnppicc_static.a, +libnppicom.dylib, libnppicom_static.a, libnppidei.dylib, +libnppidei_static.a, libnppif.dylib, libnppif_static.a, +libnppig.dylib, libnppig_static.a, libnppim.dylib, +libnppisu_static.a, libnppitc.dylib, libnppitc_static.a, +libnpps.dylib, libnpps_static.a + +Linux + +libnppc.so, libnppc_static.a, libnppial.so, +libnppial_static.a, libnppicc.so, libnppicc_static.a, +libnppicom.so, libnppicom_static.a, libnppidei.so, +libnppidei_static.a, libnppif.so, libnppif_static.a +libnppig.so, libnppig_static.a, libnppim.so, +libnppim_static.a, libnppist.so, libnppist_static.a, +libnppisu.so, libnppisu_static.a, libnppitc.so +libnppitc_static.a, libnpps.so, libnpps_static.a + +Android + +libnppc.so, libnppc_static.a, libnppial.so, +libnppial_static.a, libnppicc.so, libnppicc_static.a, +libnppicom.so, libnppicom_static.a, libnppidei.so, +libnppidei_static.a, libnppif.so, libnppif_static.a +libnppig.so, libnppig_static.a, libnppim.so, +libnppim_static.a, libnppist.so, libnppist_static.a, +libnppisu.so, libnppisu_static.a, libnppitc.so +libnppitc_static.a, libnpps.so, libnpps_static.a + +Component + +NVIDIA JPEG Library + +Linux + +libnvjpeg.so, 
libnvjpeg_static.a + +Component + +Internal common library required for statically linking to +cuBLAS, cuSPARSE, cuFFT, cuRAND, nvJPEG and NPP + +Mac OSX + +libculibos.a + +Linux + +libculibos.a + +Component + +NVIDIA Runtime Compilation Library and Header + +All + +nvrtc.h + +Windows + +nvrtc.dll, nvrtc-builtins.dll + +Mac OSX + +libnvrtc.dylib, libnvrtc-builtins.dylib + +Linux + +libnvrtc.so, libnvrtc-builtins.so + +Component + +NVIDIA Optimizing Compiler Library + +Windows + +nvvm.dll + +Mac OSX + +libnvvm.dylib + +Linux + +libnvvm.so + +Component + +NVIDIA Common Device Math Functions Library + +Windows + +libdevice.10.bc + +Mac OSX + +libdevice.10.bc + +Linux + +libdevice.10.bc + +Component + +CUDA Occupancy Calculation Header Library + +All + +cuda_occupancy.h + +Component + +CUDA Half Precision Headers + +All + +cuda_fp16.h, cuda_fp16.hpp + +Component + +CUDA Profiling Tools Interface (CUPTI) Library + +Windows + +cupti.dll + +Mac OSX + +libcupti.dylib + +Linux + +libcupti.so + +Component + +NVIDIA Tools Extension Library + +Windows + +nvToolsExt.dll, nvToolsExt.lib + +Mac OSX + +libnvToolsExt.dylib + +Linux + +libnvToolsExt.so + +Component + +NVIDIA CUDA Driver Libraries + +Linux + +libcuda.so, libnvidia-fatbinaryloader.so, +libnvidia-ptxjitcompiler.so + +The NVIDIA CUDA Driver Libraries are only distributable in +applications that meet this criteria: + + 1. The application was developed starting from a NVIDIA CUDA + container obtained from Docker Hub or the NVIDIA GPU + Cloud, and + + 2. The resulting application is packaged as a Docker + container and distributed to users on Docker Hub or the + NVIDIA GPU Cloud only. + + +2.7. Attachment B + + +Additional Licensing Obligations + +The following third party components included in the SOFTWARE +are licensed to Licensee pursuant to the following terms and +conditions: + + 1. Licensee's use of the GDB third party component is + subject to the terms and conditions of GNU GPL v3: + + This product includes copyrighted third-party software licensed + under the terms of the GNU General Public License v3 ("GPL v3"). + All third-party software packages are copyright by their respective + authors. GPL v3 terms and conditions are hereby incorporated into + the Agreement by this reference: http://www.gnu.org/licenses/gpl.txt + + Consistent with these licensing requirements, the software + listed below is provided under the terms of the specified + open source software licenses. To obtain source code for + software provided under licenses that require + redistribution of source code, including the GNU General + Public License (GPL) and GNU Lesser General Public License + (LGPL), contact oss-requests@nvidia.com. This offer is + valid for a period of three (3) years from the date of the + distribution of this product by NVIDIA CORPORATION. + + Component License + CUDA-GDB GPL v3 + + 2. Licensee represents and warrants that any and all third + party licensing and/or royalty payment obligations in + connection with Licensee's use of the H.264 video codecs + are solely the responsibility of Licensee. + + 3. Licensee's use of the Thrust library is subject to the + terms and conditions of the Apache License Version 2.0. + All third-party software packages are copyright by their + respective authors. Apache License Version 2.0 terms and + conditions are hereby incorporated into the Agreement by + this reference. 
+ http://www.apache.org/licenses/LICENSE-2.0.html + + In addition, Licensee acknowledges the following notice: + Thrust includes source code from the Boost Iterator, + Tuple, System, and Random Number libraries. + + Boost Software License - Version 1.0 - August 17th, 2003 + . . . . + + Permission is hereby granted, free of charge, to any person or + organization obtaining a copy of the software and accompanying + documentation covered by this license (the "Software") to use, + reproduce, display, distribute, execute, and transmit the Software, + and to prepare derivative works of the Software, and to permit + third-parties to whom the Software is furnished to do so, all + subject to the following: + + The copyright notices in the Software and this entire statement, + including the above license grant, this restriction and the following + disclaimer, must be included in all copies of the Software, in whole + or in part, and all derivative works of the Software, unless such + copies or derivative works are solely in the form of machine-executable + object code generated by a source language processor. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE AND + NON-INFRINGEMENT. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR + ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE FOR ANY DAMAGES OR + OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. + + 4. Licensee's use of the LLVM third party component is + subject to the following terms and conditions: + + ====================================================== + LLVM Release License + ====================================================== + University of Illinois/NCSA + Open Source License + + Copyright (c) 2003-2010 University of Illinois at Urbana-Champaign. + All rights reserved. + + Developed by: + + LLVM Team + + University of Illinois at Urbana-Champaign + + http://llvm.org + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to + deal with the Software without restriction, including without limitation the + rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + sell copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimers. + + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimers in the + documentation and/or other materials provided with the distribution. + + * Neither the names of the LLVM Team, University of Illinois at Urbana- + Champaign, nor the names of its contributors may be used to endorse or + promote products derived from this Software without specific prior + written permission. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL + THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR + OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS WITH THE SOFTWARE. + + 5. Licensee's use (e.g. nvprof) of the PCRE third party + component is subject to the following terms and + conditions: + + ------------ + PCRE LICENCE + ------------ + PCRE is a library of functions to support regular expressions whose syntax + and semantics are as close as possible to those of the Perl 5 language. + Release 8 of PCRE is distributed under the terms of the "BSD" licence, as + specified below. The documentation for PCRE, supplied in the "doc" + directory, is distributed under the same terms as the software itself. The + basic library functions are written in C and are freestanding. Also + included in the distribution is a set of C++ wrapper functions, and a just- + in-time compiler that can be used to optimize pattern matching. These are + both optional features that can be omitted when the library is built. + + THE BASIC LIBRARY FUNCTIONS + --------------------------- + Written by: Philip Hazel + Email local part: ph10 + Email domain: cam.ac.uk + University of Cambridge Computing Service, + Cambridge, England. + Copyright (c) 1997-2012 University of Cambridge + All rights reserved. + + PCRE JUST-IN-TIME COMPILATION SUPPORT + ------------------------------------- + Written by: Zoltan Herczeg + Email local part: hzmester + Emain domain: freemail.hu + Copyright(c) 2010-2012 Zoltan Herczeg + All rights reserved. + + STACK-LESS JUST-IN-TIME COMPILER + -------------------------------- + Written by: Zoltan Herczeg + Email local part: hzmester + Emain domain: freemail.hu + Copyright(c) 2009-2012 Zoltan Herczeg + All rights reserved. + + THE C++ WRAPPER FUNCTIONS + ------------------------- + Contributed by: Google Inc. + Copyright (c) 2007-2012, Google Inc. + All rights reserved. + + THE "BSD" LICENCE + ----------------- + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + * Neither the name of the University of Cambridge nor the name of Google + Inc. nor the names of their contributors may be used to endorse or + promote products derived from this software without specific prior + written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. + + 6. 
Some of the cuBLAS library routines were written by or + derived from code written by Vasily Volkov and are subject + to the Modified Berkeley Software Distribution License as + follows: + + Copyright (c) 2007-2009, Regents of the University of California + + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + * Neither the name of the University of California, Berkeley nor + the names of its contributors may be used to endorse or promote + products derived from this software without specific prior + written permission. + + THIS SOFTWARE IS PROVIDED BY THE AUTHOR "AS IS" AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, + INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING + IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. + + 7. Some of the cuBLAS library routines were written by or + derived from code written by Davide Barbieri and are + subject to the Modified Berkeley Software Distribution + License as follows: + + Copyright (c) 2008-2009 Davide Barbieri @ University of Rome Tor Vergata. + + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + * The name of the author may not be used to endorse or promote + products derived from this software without specific prior + written permission. + + THIS SOFTWARE IS PROVIDED BY THE AUTHOR "AS IS" AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, + INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING + IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. + + 8. 
Some of the cuBLAS library routines were derived from + code developed by the University of Tennessee and are + subject to the Modified Berkeley Software Distribution + License as follows: + + Copyright (c) 2010 The University of Tennessee. + + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer listed in this license in the documentation and/or + other materials provided with the distribution. + * Neither the name of the copyright holders nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + 9. Some of the cuBLAS library routines were written by or + derived from code written by Jonathan Hogg and are subject + to the Modified Berkeley Software Distribution License as + follows: + + Copyright (c) 2012, The Science and Technology Facilities Council (STFC). + + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + * Neither the name of the STFC nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE STFC BE + LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE + OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN + IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + 10. 
Some of the cuBLAS library routines were written by or + derived from code written by Ahmad M. Abdelfattah, David + Keyes, and Hatem Ltaief, and are subject to the Apache + License, Version 2.0, as follows: + + -- (C) Copyright 2013 King Abdullah University of Science and Technology + Authors: + Ahmad Abdelfattah (ahmad.ahmad@kaust.edu.sa) + David Keyes (david.keyes@kaust.edu.sa) + Hatem Ltaief (hatem.ltaief@kaust.edu.sa) + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the King Abdullah University of Science and + Technology nor the names of its contributors may be used to endorse + or promote products derived from this software without specific prior + written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE + + 11. Some of the cuSPARSE library routines were written by or + derived from code written by Li-Wen Chang and are subject + to the NCSA Open Source License as follows: + + Copyright (c) 2012, University of Illinois. + + All rights reserved. + + Developed by: IMPACT Group, University of Illinois, http://impact.crhc.illinois.edu + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal with the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimers in the documentation and/or other materials provided + with the distribution. + * Neither the names of IMPACT Group, University of Illinois, nor + the names of its contributors may be used to endorse or promote + products derived from this Software without specific prior + written permission. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. 
IN NO EVENT SHALL THE CONTRIBUTORS OR COPYRIGHT + HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR + IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE + SOFTWARE. + + 12. Some of the cuRAND library routines were written by or + derived from code written by Mutsuo Saito and Makoto + Matsumoto and are subject to the following license: + + Copyright (c) 2009, 2010 Mutsuo Saito, Makoto Matsumoto and Hiroshima + University. All rights reserved. + + Copyright (c) 2011 Mutsuo Saito, Makoto Matsumoto, Hiroshima + University and University of Tokyo. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + * Neither the name of the Hiroshima University nor the names of + its contributors may be used to endorse or promote products + derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + 13. Some of the cuRAND library routines were derived from + code developed by D. E. Shaw Research and are subject to + the following license: + + Copyright 2010-2011, D. E. Shaw Research. + + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions, and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions, and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + * Neither the name of D. E. Shaw Research nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + 14. Some of the Math library routines were written by or + derived from code developed by Norbert Juffa and are + subject to the following license: + + Copyright (c) 2015-2017, Norbert Juffa + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + 15. Licensee's use of the lz4 third party component is + subject to the following terms and conditions: + + Copyright (C) 2011-2013, Yann Collet. + BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + 16. 
The NPP library uses code from the Boost Math Toolkit, + and is subject to the following license: + + Boost Software License - Version 1.0 - August 17th, 2003 + . . . . + + Permission is hereby granted, free of charge, to any person or + organization obtaining a copy of the software and accompanying + documentation covered by this license (the "Software") to use, + reproduce, display, distribute, execute, and transmit the Software, + and to prepare derivative works of the Software, and to permit + third-parties to whom the Software is furnished to do so, all + subject to the following: + + The copyright notices in the Software and this entire statement, + including the above license grant, this restriction and the following + disclaimer, must be included in all copies of the Software, in whole + or in part, and all derivative works of the Software, unless such + copies or derivative works are solely in the form of machine-executable + object code generated by a source language processor. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE AND + NON-INFRINGEMENT. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR + ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE FOR ANY DAMAGES OR + OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. + + 17. Portions of the Nsight Eclipse Edition is subject to the + following license: + + The Eclipse Foundation makes available all content in this plug-in + ("Content"). Unless otherwise indicated below, the Content is provided + to you under the terms and conditions of the Eclipse Public License + Version 1.0 ("EPL"). A copy of the EPL is available at http:// + www.eclipse.org/legal/epl-v10.html. For purposes of the EPL, "Program" + will mean the Content. + + If you did not receive this Content directly from the Eclipse + Foundation, the Content is being redistributed by another party + ("Redistributor") and different terms and conditions may apply to your + use of any object code in the Content. Check the Redistributor's + license that was provided with the Content. If no such license exists, + contact the Redistributor. Unless otherwise indicated below, the terms + and conditions of the EPL still apply to any source code in the + Content and such source code may be obtained at http://www.eclipse.org. + + 18. Some of the cuBLAS library routines uses code from + OpenAI, which is subject to the following license: + + License URL + https://github.com/openai/openai-gemm/blob/master/LICENSE + + License Text + The MIT License + + Copyright (c) 2016 OpenAI (http://openai.com), 2016 Google Inc. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. 
+ + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. + + 19. Licensee's use of the Visual Studio Setup Configuration + Samples is subject to the following license: + + The MIT License (MIT) + Copyright (C) Microsoft Corporation. All rights reserved. + + Permission is hereby granted, free of charge, to any person + obtaining a copy of this software and associated documentation + files (the "Software"), to deal in the Software without restriction, + including without limitation the rights to use, copy, modify, merge, + publish, distribute, sublicense, and/or sell copies of the Software, + and to permit persons to whom the Software is furnished to do so, + subject to the following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + 20. Licensee's use of linmath.h header for CPU functions for + GL vector/matrix operations from lunarG is subject to the + Apache License Version 2.0. + + 21. The DX12-CUDA sample uses the d3dx12.h header, which is + subject to the MIT license . 
+ +----------------- diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/METADATA b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..cb533db38cb400ef0b0e71cea45e7ce506422bfd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/METADATA @@ -0,0 +1,44 @@ +Metadata-Version: 2.2 +Name: nvidia-cublas-cu12 +Version: 12.8.4.1 +Summary: CUBLAS native runtime libraries +Home-page: https://developer.nvidia.com/cuda-zone +Author: Nvidia CUDA Installer Team +Author-email: compute_installer@nvidia.com +License: NVIDIA Proprietary Software +Keywords: cuda,nvidia,runtime,machine learning,deep learning +Classifier: Development Status :: 4 - Beta +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Education +Classifier: Intended Audience :: Science/Research +Classifier: License :: Other/Proprietary License +Classifier: Natural Language :: English +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.5 +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Topic :: Scientific/Engineering +Classifier: Topic :: Scientific/Engineering :: Mathematics +Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence +Classifier: Topic :: Software Development +Classifier: Topic :: Software Development :: Libraries +Classifier: Operating System :: Microsoft :: Windows +Classifier: Operating System :: POSIX :: Linux +Requires-Python: >=3 +License-File: License.txt +Dynamic: author +Dynamic: author-email +Dynamic: classifier +Dynamic: description +Dynamic: home-page +Dynamic: keywords +Dynamic: license +Dynamic: requires-python +Dynamic: summary + +CUBLAS native runtime libraries diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/RECORD b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..d794952a3781dd8c9ca175c04b28d6286811e807 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/RECORD @@ -0,0 +1,23 @@ +nvidia/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +nvidia/__pycache__/__init__.cpython-312.pyc,, +nvidia/cublas/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +nvidia/cublas/__pycache__/__init__.cpython-312.pyc,, +nvidia/cublas/include/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +nvidia/cublas/include/__pycache__/__init__.cpython-312.pyc,, +nvidia/cublas/include/cublas.h,sha256=a0lLqy-k47NuwyDjuueC3W0Mpc908MTU7o5sMJqE-1w,41246 +nvidia/cublas/include/cublasLt.h,sha256=oH9TR01H5CWfGPIjJk6-Ljg8eQOu0gho7TNt83gCmwg,102451 +nvidia/cublas/include/cublasXt.h,sha256=CW9dyXYGSUW1wEXrVVyhU6OxBK1PUvMoYdVGlQT7L9A,37380 +nvidia/cublas/include/cublas_api.h,sha256=A4Jvv9elvoJIDw8ReUP7qBnSxEvdslc-ghW_ycq_qAU,374363 
+nvidia/cublas/include/cublas_v2.h,sha256=qxMdB5jb97luEfw61LEAB-Wlr8A9DLBvO4rRypDCNKw,15460 +nvidia/cublas/include/nvblas.h,sha256=dXCLR-2oUiJFzLsDtIAK09m42ct4G0HWdYzBUuDPXpc,23341 +nvidia/cublas/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +nvidia/cublas/lib/__pycache__/__init__.cpython-312.pyc,, +nvidia/cublas/lib/libcublas.so.12,sha256=Axzmwsv7uUaPBAUnyrXFmQac5WCec-KPh1A4gQY-rCE,116388640 +nvidia/cublas/lib/libcublasLt.so.12,sha256=ELXmYxz4EVxmHriV7RUzgmMItY95VkZvU9I2pAybYiw,751771728 +nvidia/cublas/lib/libnvblas.so.12,sha256=nSRIINMpZvCvR6_mye5vNLIw-v6H6gFB3HEkGl0rNAw,753824 +nvidia_cublas_cu12-12.8.4.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +nvidia_cublas_cu12-12.8.4.1.dist-info/License.txt,sha256=rW9YU_ugyg0VnQ9Y1JrkmDDC-Mk_epJki5zpCttMbM0,59262 +nvidia_cublas_cu12-12.8.4.1.dist-info/METADATA,sha256=MQVW8cLDHsUX1LBJ6vEjdAgdHFLI702Ozqsqrx7f8t8,1683 +nvidia_cublas_cu12-12.8.4.1.dist-info/RECORD,, +nvidia_cublas_cu12-12.8.4.1.dist-info/WHEEL,sha256=wwQXGCXJQEUF59fKjBnVVzOkzi78E6rE_-QuUR4ZT4w,109 +nvidia_cublas_cu12-12.8.4.1.dist-info/top_level.txt,sha256=fTkAtiFuL16nUrB9ytDDtpytz2t0B4NvYTnRzwAhO14,7 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/WHEEL b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..6a0d9dc1b21a540a280fda3e78767d22a3b7ed6e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: setuptools (75.8.1) +Root-Is-Purelib: true +Tag: py3-none-manylinux_2_27_x86_64 + diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/top_level.txt b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..862f7abf232cdfbb928609856247292e81c9decb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/nvidia_cublas_cu12-12.8.4.1.dist-info/top_level.txt @@ -0,0 +1 @@ +nvidia diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/__init__.pxd b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/__init__.pxd new file mode 100644 index 0000000000000000000000000000000000000000..8cc54b4c6bfdaa0e347b3927d7932934916a1ade --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/__init__.pxd @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from libcpp.memory cimport shared_ptr +from pyarrow.includes.libarrow cimport (CArray, CBuffer, CDataType, + CField, CRecordBatch, CSchema, + CTable, CTensor, CSparseCOOTensor, + CSparseCSRMatrix, CSparseCSCMatrix, + CSparseCSFTensor) + +cdef extern from "arrow/python/pyarrow.h" namespace "arrow::py": + cdef int import_pyarrow() except -1 + cdef object wrap_buffer(const shared_ptr[CBuffer]& buffer) + cdef object wrap_data_type(const shared_ptr[CDataType]& type) + cdef object wrap_field(const shared_ptr[CField]& field) + cdef object wrap_schema(const shared_ptr[CSchema]& schema) + cdef object wrap_array(const shared_ptr[CArray]& sp_array) + cdef object wrap_tensor(const shared_ptr[CTensor]& sp_tensor) + cdef object wrap_sparse_tensor_coo( + const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor) + cdef object wrap_sparse_tensor_csr( + const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor) + cdef object wrap_sparse_tensor_csc( + const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor) + cdef object wrap_sparse_tensor_csf( + const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor) + cdef object wrap_table(const shared_ptr[CTable]& ctable) + cdef object wrap_batch(const shared_ptr[CRecordBatch]& cbatch) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_acero.pxd b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_acero.pxd new file mode 100644 index 0000000000000000000000000000000000000000..4553aee9d6f16c391340aa45489471bdcfe0cb76 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_acero.pxd @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from pyarrow.lib cimport * +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_acero cimport * + + +cdef class ExecNodeOptions(_Weakrefable): + cdef: + shared_ptr[CExecNodeOptions] wrapped + + cdef void init(self, const shared_ptr[CExecNodeOptions]& sp) + cdef inline shared_ptr[CExecNodeOptions] unwrap(self) nogil + + +cdef class Declaration(_Weakrefable): + + cdef: + CDeclaration decl + + cdef void init(self, const CDeclaration& c_decl) + + @staticmethod + cdef wrap(const CDeclaration& c_decl) + + cdef inline CDeclaration unwrap(self) nogil diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_acero.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_acero.pyx new file mode 100644 index 0000000000000000000000000000000000000000..9e8cbd65be224bb255448b580b44f0575942fc1e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_acero.pyx @@ -0,0 +1,608 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# --------------------------------------------------------------------- +# Low-level Acero bindings + +# cython: profile=False +# distutils: language = c++ +# cython: language_level = 3 + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_acero cimport * +from pyarrow.lib cimport (Table, pyarrow_unwrap_table, pyarrow_wrap_table, + RecordBatchReader) +from pyarrow.lib import frombytes, tobytes +from pyarrow._compute cimport ( + Expression, FunctionOptions, _ensure_field_ref, _true, + unwrap_null_placement, unwrap_sort_order +) + + +cdef class ExecNodeOptions(_Weakrefable): + """ + Base class for the node options. + + Use one of the subclasses to construct an options object. + """ + __slots__ = () # avoid mistakingly creating attributes + + cdef void init(self, const shared_ptr[CExecNodeOptions]& sp): + self.wrapped = sp + + cdef inline shared_ptr[CExecNodeOptions] unwrap(self) nogil: + return self.wrapped + + +cdef class _TableSourceNodeOptions(ExecNodeOptions): + + def _set_options(self, Table table): + cdef: + shared_ptr[CTable] c_table + + c_table = pyarrow_unwrap_table(table) + self.wrapped.reset( + new CTableSourceNodeOptions(c_table) + ) + + +class TableSourceNodeOptions(_TableSourceNodeOptions): + """ + A Source node which accepts a table. + + This is the option class for the "table_source" node factory. + + Parameters + ---------- + table : pyarrow.Table + The table which acts as the data source. + """ + + def __init__(self, Table table): + self._set_options(table) + + +cdef class _FilterNodeOptions(ExecNodeOptions): + + def _set_options(self, Expression filter_expression not None): + self.wrapped.reset( + new CFilterNodeOptions(filter_expression.unwrap()) + ) + + +class FilterNodeOptions(_FilterNodeOptions): + """ + Make a node which excludes some rows from batches passed through it. + + This is the option class for the "filter" node factory. + + The "filter" operation provides an option to define data filtering + criteria. It selects rows where the given expression evaluates to true. + Filters can be written using pyarrow.compute.Expression, and the + expression must have a return type of boolean. 
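+
+    As a minimal usage sketch (assuming ``Declaration.from_sequence`` and
+    ``Declaration.to_table`` behave as defined elsewhere in this module, and
+    that ``pyarrow.acero`` re-exports these names), a filter node is
+    typically chained after a "table_source" node:
+
+    >>> import pyarrow as pa
+    >>> import pyarrow.compute as pc
+    >>> from pyarrow.acero import (
+    ...     Declaration, TableSourceNodeOptions, FilterNodeOptions)
+    >>> table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
+    >>> decl = Declaration.from_sequence([
+    ...     Declaration("table_source", TableSourceNodeOptions(table)),
+    ...     Declaration("filter", FilterNodeOptions(pc.field("a") > 1)),
+    ... ])
+    >>> decl.to_table().num_rows
+    2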
+ + Parameters + ---------- + filter_expression : pyarrow.compute.Expression + """ + + def __init__(self, Expression filter_expression): + self._set_options(filter_expression) + + +cdef class _ProjectNodeOptions(ExecNodeOptions): + + def _set_options(self, expressions, names=None): + cdef: + Expression expr + vector[CExpression] c_expressions + vector[c_string] c_names + + for expr in expressions: + c_expressions.push_back(expr.unwrap()) + + if names is not None: + if len(names) != len(expressions): + raise ValueError( + "The number of names should be equal to the number of expressions" + ) + + for name in names: + c_names.push_back(tobytes(name)) + + self.wrapped.reset( + new CProjectNodeOptions(c_expressions, c_names) + ) + else: + self.wrapped.reset( + new CProjectNodeOptions(c_expressions) + ) + + +class ProjectNodeOptions(_ProjectNodeOptions): + """ + Make a node which executes expressions on input batches, + producing batches of the same length with new columns. + + This is the option class for the "project" node factory. + + The "project" operation rearranges, deletes, transforms, and + creates columns. Each output column is computed by evaluating + an expression against the source record batch. These must be + scalar expressions (expressions consisting of scalar literals, + field references and scalar functions, i.e. elementwise functions + that return one value for each input row independent of the value + of all other rows). + + Parameters + ---------- + expressions : list of pyarrow.compute.Expression + List of expressions to evaluate against the source batch. This must + be scalar expressions. + names : list of str, optional + List of names for each of the output columns (same length as + `expressions`). If `names` is not provided, the string + representations of exprs will be used. + """ + + def __init__(self, expressions, names=None): + self._set_options(expressions, names) + + +cdef class _AggregateNodeOptions(ExecNodeOptions): + + def _set_options(self, aggregates, keys=None): + cdef: + CAggregate c_aggr + vector[CAggregate] c_aggregations + vector[CFieldRef] c_keys + + for arg_names, func_name, opts, name in aggregates: + c_aggr.function = tobytes(func_name) + if opts is not None: + c_aggr.options = (opts).wrapped + else: + c_aggr.options = nullptr + if not isinstance(arg_names, (list, tuple)): + arg_names = [arg_names] + for arg in arg_names: + c_aggr.target.push_back(_ensure_field_ref(arg)) + c_aggr.name = tobytes(name) + + c_aggregations.push_back(move(c_aggr)) + + if keys is None: + keys = [] + for name in keys: + c_keys.push_back(_ensure_field_ref(name)) + + self.wrapped.reset( + new CAggregateNodeOptions(c_aggregations, c_keys) + ) + + +class AggregateNodeOptions(_AggregateNodeOptions): + """ + Make a node which aggregates input batches, optionally grouped by keys. + + This is the option class for the "aggregate" node factory. + + Acero supports two types of aggregates: "scalar" aggregates, + and "hash" aggregates. Scalar aggregates reduce an array or scalar + input to a single scalar output (e.g. computing the mean of a column). + Hash aggregates act like GROUP BY in SQL and first partition data + based on one or more key columns, then reduce the data in each partition. + The aggregate node supports both types of computation, and can compute + any number of aggregations at once. + + Parameters + ---------- + aggregates : list of tuples + Aggregations which will be applied to the targeted fields. 
+ Specified as a list of tuples, where each tuple is one aggregation + specification and consists of: aggregation target column(s) followed + by function name, aggregation function options object and the + output field name. + The target column(s) specification can be a single field reference, + an empty list or a list of fields unary, nullary and n-ary aggregation + functions respectively. Each field reference can be a string + column name or expression. + keys : list of field references, optional + Keys by which aggregations will be grouped. Each key can reference + a field using a string name or expression. + """ + + def __init__(self, aggregates, keys=None): + self._set_options(aggregates, keys) + + +cdef class _OrderByNodeOptions(ExecNodeOptions): + + def _set_options(self, sort_keys, null_placement): + cdef: + vector[CSortKey] c_sort_keys + + for name, order in sort_keys: + c_sort_keys.push_back( + CSortKey(_ensure_field_ref(name), unwrap_sort_order(order)) + ) + + self.wrapped.reset( + new COrderByNodeOptions( + COrdering(c_sort_keys, unwrap_null_placement(null_placement)) + ) + ) + + +class OrderByNodeOptions(_OrderByNodeOptions): + """ + Make a node which applies a new ordering to the data. + + Currently this node works by accumulating all data, sorting, and then + emitting the new data with an updated batch index. + Larger-than-memory sort is not currently supported. + + This is the option class for the "order_by" node factory. + + Parameters + ---------- + sort_keys : sequence of (name, order) tuples + Names of field/column keys to sort the input on, + along with the order each field/column is sorted in. + Accepted values for `order` are "ascending", "descending". + Each field reference can be a string column name or expression. + null_placement : str, default "at_end" + Where nulls in input should be sorted, only applying to + columns/fields mentioned in `sort_keys`. + Accepted values are "at_start", "at_end". 
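# A minimal sketch of the "aggregate" node described above, assuming the public
# pyarrow.acero wrapper; table contents and names are invented for illustration.
import pyarrow as pa
from pyarrow.acero import Declaration, TableSourceNodeOptions, AggregateNodeOptions

table = pa.table({"key": ["x", "x", "y"], "value": [1, 2, 3]})
source = Declaration("table_source", TableSourceNodeOptions(table))

# Hash aggregation: group by "key" and sum "value" into an output column
# "value_sum". Each aggregate is a (target, function name, options, output name) tuple.
aggregate = Declaration(
    "aggregate",
    AggregateNodeOptions([("value", "hash_sum", None, "value_sum")], keys=["key"]),
    inputs=[source],
)
print(aggregate.to_table())  # one row per distinct key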
+ """ + + def __init__(self, sort_keys=(), *, null_placement="at_end"): + self._set_options(sort_keys, null_placement) + + +cdef class _HashJoinNodeOptions(ExecNodeOptions): + + def _set_options( + self, join_type, left_keys, right_keys, left_output=None, right_output=None, + output_suffix_for_left="", output_suffix_for_right="", + ): + cdef: + CJoinType c_join_type + vector[CFieldRef] c_left_keys + vector[CFieldRef] c_right_keys + vector[CFieldRef] c_left_output + vector[CFieldRef] c_right_output + + # join type + if join_type == "left semi": + c_join_type = CJoinType_LEFT_SEMI + elif join_type == "right semi": + c_join_type = CJoinType_RIGHT_SEMI + elif join_type == "left anti": + c_join_type = CJoinType_LEFT_ANTI + elif join_type == "right anti": + c_join_type = CJoinType_RIGHT_ANTI + elif join_type == "inner": + c_join_type = CJoinType_INNER + elif join_type == "left outer": + c_join_type = CJoinType_LEFT_OUTER + elif join_type == "right outer": + c_join_type = CJoinType_RIGHT_OUTER + elif join_type == "full outer": + c_join_type = CJoinType_FULL_OUTER + else: + raise ValueError("Unsupported join type") + + # left/right keys + if not isinstance(left_keys, (list, tuple)): + left_keys = [left_keys] + for key in left_keys: + c_left_keys.push_back(_ensure_field_ref(key)) + if not isinstance(right_keys, (list, tuple)): + right_keys = [right_keys] + for key in right_keys: + c_right_keys.push_back(_ensure_field_ref(key)) + + # left/right output fields + if left_output is not None and right_output is not None: + for colname in left_output: + c_left_output.push_back(_ensure_field_ref(colname)) + for colname in right_output: + c_right_output.push_back(_ensure_field_ref(colname)) + + self.wrapped.reset( + new CHashJoinNodeOptions( + c_join_type, c_left_keys, c_right_keys, + c_left_output, c_right_output, + _true, + tobytes(output_suffix_for_left), + tobytes(output_suffix_for_right) + ) + ) + else: + self.wrapped.reset( + new CHashJoinNodeOptions( + c_join_type, c_left_keys, c_right_keys, + _true, + tobytes(output_suffix_for_left), + tobytes(output_suffix_for_right) + ) + ) + + +class HashJoinNodeOptions(_HashJoinNodeOptions): + """ + Make a node which implements join operation using hash join strategy. + + This is the option class for the "hashjoin" node factory. + + Parameters + ---------- + join_type : str + Type of join. One of "left semi", "right semi", "left anti", + "right anti", "inner", "left outer", "right outer", "full outer". + left_keys : str, Expression or list + Key fields from left input. Each key can be a string column name + or a field expression, or a list of such field references. + right_keys : str, Expression or list + Key fields from right input. See `left_keys` for details. + left_output : list, optional + List of output fields passed from left input. If left and right + output fields are not specified, all valid fields from both left and + right input will be output. Each field can be a string column name + or a field expression. + right_output : list, optional + List of output fields passed from right input. If left and right + output fields are not specified, all valid fields from both left and + right input will be output. Each field can be a string column name + or a field expression. + output_suffix_for_left : str + Suffix added to names of output fields coming from left input + (used to distinguish, if necessary, between fields of the same + name in left and right input and can be left empty if there are + no name collisions). 
+ output_suffix_for_right : str + Suffix added to names of output fields coming from right input, + see `output_suffix_for_left` for details. + """ + + def __init__( + self, join_type, left_keys, right_keys, left_output=None, right_output=None, + output_suffix_for_left="", output_suffix_for_right="" + ): + self._set_options( + join_type, left_keys, right_keys, left_output, right_output, + output_suffix_for_left, output_suffix_for_right + ) + + +cdef class _AsofJoinNodeOptions(ExecNodeOptions): + + def _set_options(self, left_on, left_by, right_on, right_by, tolerance): + cdef: + vector[CFieldRef] c_left_by + vector[CFieldRef] c_right_by + CAsofJoinKeys c_left_keys + CAsofJoinKeys c_right_keys + vector[CAsofJoinKeys] c_input_keys + + # Prepare left AsofJoinNodeOption::Keys + if not isinstance(left_by, (list, tuple)): + left_by = [left_by] + for key in left_by: + c_left_by.push_back(_ensure_field_ref(key)) + + c_left_keys.on_key = _ensure_field_ref(left_on) + c_left_keys.by_key = c_left_by + + c_input_keys.push_back(c_left_keys) + + # Prepare right AsofJoinNodeOption::Keys + if not isinstance(right_by, (list, tuple)): + right_by = [right_by] + for key in right_by: + c_right_by.push_back(_ensure_field_ref(key)) + + c_right_keys.on_key = _ensure_field_ref(right_on) + c_right_keys.by_key = c_right_by + + c_input_keys.push_back(c_right_keys) + + self.wrapped.reset( + new CAsofJoinNodeOptions( + c_input_keys, + tolerance, + ) + ) + + +class AsofJoinNodeOptions(_AsofJoinNodeOptions): + """ + Make a node which implements 'as of join' operation. + + This is the option class for the "asofjoin" node factory. + + Parameters + ---------- + left_on : str, Expression + The left key on which the join operation should be performed. + Can be a string column name or a field expression. + + An inexact match is used on the "on" key, i.e. a row is considered a + match if and only if left_on - tolerance <= right_on <= left_on. + + The input dataset must be sorted by the "on" key. Must be a single + field of a common type. + + Currently, the "on" key must be an integer, date, or timestamp type. + left_by: str, Expression or list + The left keys on which the join operation should be performed. + Exact equality is used for each field of the "by" keys. + Each key can be a string column name or a field expression, + or a list of such field references. + right_on : str, Expression + The right key on which the join operation should be performed. + See `left_on` for details. + right_by: str, Expression or list + The right keys on which the join operation should be performed. + See `left_by` for details. + tolerance : int + The tolerance to use for the asof join. The tolerance is interpreted in + the same units as the "on" key. + """ + + def __init__(self, left_on, left_by, right_on, right_by, tolerance): + self._set_options(left_on, left_by, right_on, right_by, tolerance) + + +cdef class Declaration(_Weakrefable): + """ + Helper class for declaring the nodes of an ExecPlan. + + A Declaration represents an unconstructed ExecNode, and potentially + more since its inputs may also be Declarations or when constructed + with ``from_sequence``. + + The possible ExecNodes to use are registered with a name, + the "factory name", and need to be specified using this name, together + with its corresponding ExecNodeOptions subclass. + + Parameters + ---------- + factory_name : str + The ExecNode factory name, such as "table_source", "filter", + "project" etc. See the ExecNodeOptions subclasses for the exact + factory names to use. 
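# A minimal sketch of the "hashjoin" node documented above, assuming the public
# pyarrow.acero wrapper; the two tables and their columns are invented for
# illustration.
import pyarrow as pa
from pyarrow.acero import Declaration, TableSourceNodeOptions, HashJoinNodeOptions

left = pa.table({"id": [1, 2, 3], "year": [2020, 2022, 2019]})
right = pa.table({"id": [3, 4], "n_legs": [5, 100]})

left_source = Declaration("table_source", TableSourceNodeOptions(left))
right_source = Declaration("table_source", TableSourceNodeOptions(right))

# The join declaration takes the left and right sources as its two inputs.
joined = Declaration(
    "hashjoin",
    HashJoinNodeOptions("left outer", left_keys="id", right_keys="id",
                        output_suffix_for_right="_right"),
    inputs=[left_source, right_source],
)
print(joined.to_table())  # all left rows, right columns null where unmatched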
+ options : ExecNodeOptions + Corresponding ExecNodeOptions subclass (matching the factory name). + inputs : list of Declaration, optional + Input nodes for this declaration. Optional if the node is a source + node, or when the declaration gets combined later with + ``from_sequence``. + + Returns + ------- + Declaration + """ + cdef void init(self, const CDeclaration& c_decl): + self.decl = c_decl + + @staticmethod + cdef wrap(const CDeclaration& c_decl): + cdef Declaration self = Declaration.__new__(Declaration) + self.init(c_decl) + return self + + cdef inline CDeclaration unwrap(self) nogil: + return self.decl + + def __init__(self, factory_name, ExecNodeOptions options, inputs=None): + cdef: + c_string c_factory_name + CDeclaration c_decl + vector[CDeclaration.Input] c_inputs + + c_factory_name = tobytes(factory_name) + + if inputs is not None: + for ipt in inputs: + c_inputs.push_back( + CDeclaration.Input((ipt).unwrap()) + ) + + c_decl = CDeclaration(c_factory_name, c_inputs, options.unwrap()) + self.init(c_decl) + + @staticmethod + def from_sequence(decls): + """ + Convenience factory for the common case of a simple sequence of nodes. + + Each of the declarations will be appended to the inputs of the + subsequent declaration, and the final modified declaration will + be returned. + + Parameters + ---------- + decls : list of Declaration + + Returns + ------- + Declaration + """ + cdef: + vector[CDeclaration] c_decls + CDeclaration c_decl + + for decl in decls: + c_decls.push_back(( decl).unwrap()) + + c_decl = CDeclaration.Sequence(c_decls) + return Declaration.wrap(c_decl) + + def __str__(self): + return frombytes(GetResultValue(DeclarationToString(self.decl))) + + def __repr__(self): + return "\n{0}".format(str(self)) + + def to_table(self, bint use_threads=True): + """ + Run the declaration and collect the results into a table. + + This method will implicitly add a sink node to the declaration + to collect results into a table. It will then create an ExecPlan + from the declaration, start the exec plan, block until the plan + has finished, and return the created table. + + Parameters + ---------- + use_threads : bool, default True + If set to False, then all CPU work will be done on the calling + thread. I/O tasks will still happen on the I/O executor + and may be multi-threaded (but should not use significant CPU + resources). + + Returns + ------- + pyarrow.Table + """ + cdef: + shared_ptr[CTable] c_table + + with nogil: + c_table = GetResultValue(DeclarationToTable(self.unwrap(), use_threads)) + return pyarrow_wrap_table(c_table) + + def to_reader(self, bint use_threads=True): + """Run the declaration and return results as a RecordBatchReader. + + For details about the parameters, see `to_table`. + + Returns + ------- + pyarrow.RecordBatchReader + """ + cdef: + RecordBatchReader reader + reader = RecordBatchReader.__new__(RecordBatchReader) + reader.reader.reset( + GetResultValue(DeclarationToReader(self.unwrap(), use_threads)).release() + ) + return reader diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_compute.pxd b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_compute.pxd new file mode 100644 index 0000000000000000000000000000000000000000..29b37da3ac4ef36106b10a09d7583bdba8d1a260 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_compute.pxd @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from pyarrow.lib cimport * +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * + +cdef class UdfContext(_Weakrefable): + cdef: + CUdfContext c_context + + cdef void init(self, const CUdfContext& c_context) + + +cdef class FunctionOptions(_Weakrefable): + cdef: + shared_ptr[CFunctionOptions] wrapped + + cdef const CFunctionOptions* get_options(self) except NULL + cdef void init(self, const shared_ptr[CFunctionOptions]& sp) + + cdef inline shared_ptr[CFunctionOptions] unwrap(self) + + +cdef class _SortOptions(FunctionOptions): + pass + + +cdef CExpression _bind(Expression filter, Schema schema) except * + + +cdef class Expression(_Weakrefable): + + cdef: + CExpression expr + + cdef void init(self, const CExpression& sp) + + @staticmethod + cdef wrap(const CExpression& sp) + + cdef inline CExpression unwrap(self) + + @staticmethod + cdef Expression _expr_or_scalar(object expr) + + +cdef CExpression _true + +cdef CFieldRef _ensure_field_ref(value) except * + +cdef CSortOrder unwrap_sort_order(order) except * + +cdef CNullPlacement unwrap_null_placement(null_placement) except * diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_csv.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_csv.pyx new file mode 100644 index 0000000000000000000000000000000000000000..508488c0c3b3c3bcd2d2157f57f625b1e5b92c2e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_csv.pyx @@ -0,0 +1,1542 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
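# The _compute.pxd declarations above expose Expression, FunctionOptions and the
# field-reference/sort-order helpers to other Cython modules (such as the Acero
# bindings earlier in this patch). From Python, the same Expression objects are
# built through pyarrow.compute; a small sketch with invented column names:
import pyarrow.compute as pc

expr = (pc.field("n_legs") > 2) & pc.field("animals").isin(["Horse", "Centipede"])
print(expr)  # a boolean Expression, usable e.g. with FilterNodeOptions or datasets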
+ +# cython: profile=False +# distutils: language = c++ +# cython: language_level = 3 + +from cython.operator cimport dereference as deref + +from collections import namedtuple +from collections.abc import Mapping + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_python cimport * +from pyarrow.lib cimport (check_status, Field, MemoryPool, Schema, + RecordBatchReader, ensure_type, + maybe_unbox_memory_pool, get_input_stream, + get_writer, native_transcoding_input_stream, + pyarrow_unwrap_batch, pyarrow_unwrap_schema, + pyarrow_unwrap_table, pyarrow_wrap_schema, + pyarrow_wrap_table, pyarrow_wrap_data_type, + pyarrow_unwrap_data_type, Table, RecordBatch, + StopToken, _CRecordBatchWriter) +from pyarrow.lib import frombytes, tobytes, SignalStopHandler + + +cdef unsigned char _single_char(s) except 0: + val = ord(s) + if val == 0 or val > 127: + raise ValueError("Expecting an ASCII character") + return val + + +_InvalidRow = namedtuple( + "_InvalidRow", ("expected_columns", "actual_columns", "number", "text"), + module=__name__) + + +class InvalidRow(_InvalidRow): + """ + Description of an invalid row in a CSV file. + + Parameters + ---------- + expected_columns : int + The expected number of columns in the row. + actual_columns : int + The actual number of columns in the row. + number : int or None + The physical row number if known, otherwise None. + text : str + The contents of the row. + """ + __slots__ = () + + +cdef CInvalidRowResult _handle_invalid_row( + handler, const CCSVInvalidRow& c_row) except CInvalidRowResult_Error: + # A negative row number means undetermined (because of parallel reading) + row_number = c_row.number if c_row.number >= 0 else None + row = InvalidRow(c_row.expected_columns, c_row.actual_columns, + row_number, frombytes( c_row.text)) + result = handler(row) + if result == 'error': + return CInvalidRowResult_Error + elif result == 'skip': + return CInvalidRowResult_Skip + else: + raise ValueError("Invalid return value for invalid row handler: " + f"expected 'error' or 'skip', got {result!r}") + + +cdef class ReadOptions(_Weakrefable): + """ + Options for reading CSV files. + + Parameters + ---------- + use_threads : bool, optional (default True) + Whether to use multiple threads to accelerate reading + block_size : int, optional + How much bytes to process at a time from the input stream. + This will determine multi-threading granularity as well as + the size of individual record batches or table chunks. + Minimum valid value for block size is 1 + skip_rows : int, optional (default 0) + The number of rows to skip before the column names (if any) + and the CSV data. + skip_rows_after_names : int, optional (default 0) + The number of rows to skip after the column names. + This number can be larger than the number of rows in one + block, and empty rows are counted. + The order of application is as follows: + - `skip_rows` is applied (if non-zero); + - column names are read (unless `column_names` is set); + - `skip_rows_after_names` is applied (if non-zero). + column_names : list, optional + The column names of the target table. If empty, fall back on + `autogenerate_column_names`. + autogenerate_column_names : bool, optional (default False) + Whether to autogenerate column names if `column_names` is empty. + If true, column names will be of the form "f0", "f1"... + If false, column names will be read from the first CSV row + after `skip_rows`. 
+ encoding : str, optional (default 'utf8') + The character encoding of the CSV data. Columns that cannot + decode using this encoding can still be read as Binary. + + Examples + -------- + + Defining an example data: + + >>> import io + >>> s = "1,2,3\\nFlamingo,2,2022-03-01\\nHorse,4,2022-03-02\\nBrittle stars,5,2022-03-03\\nCentipede,100,2022-03-04" + >>> print(s) + 1,2,3 + Flamingo,2,2022-03-01 + Horse,4,2022-03-02 + Brittle stars,5,2022-03-03 + Centipede,100,2022-03-04 + + Ignore the first numbered row and substitute it with defined + or autogenerated column names: + + >>> from pyarrow import csv + >>> read_options = csv.ReadOptions( + ... column_names=["animals", "n_legs", "entry"], + ... skip_rows=1) + >>> csv.read_csv(io.BytesIO(s.encode()), read_options=read_options) + pyarrow.Table + animals: string + n_legs: int64 + entry: date32[day] + ---- + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + n_legs: [[2,4,5,100]] + entry: [[2022-03-01,2022-03-02,2022-03-03,2022-03-04]] + + >>> read_options = csv.ReadOptions(autogenerate_column_names=True, + ... skip_rows=1) + >>> csv.read_csv(io.BytesIO(s.encode()), read_options=read_options) + pyarrow.Table + f0: string + f1: int64 + f2: date32[day] + ---- + f0: [["Flamingo","Horse","Brittle stars","Centipede"]] + f1: [[2,4,5,100]] + f2: [[2022-03-01,2022-03-02,2022-03-03,2022-03-04]] + + Remove the first 2 rows of the data: + + >>> read_options = csv.ReadOptions(skip_rows_after_names=2) + >>> csv.read_csv(io.BytesIO(s.encode()), read_options=read_options) + pyarrow.Table + 1: string + 2: int64 + 3: date32[day] + ---- + 1: [["Brittle stars","Centipede"]] + 2: [[5,100]] + 3: [[2022-03-03,2022-03-04]] + """ + + # Avoid mistakingly creating attributes + __slots__ = () + + # __init__() is not called when unpickling, initialize storage here + def __cinit__(self, *argw, **kwargs): + self.options.reset(new CCSVReadOptions(CCSVReadOptions.Defaults())) + + def __init__(self, *, use_threads=None, block_size=None, skip_rows=None, + skip_rows_after_names=None, column_names=None, + autogenerate_column_names=None, encoding='utf8'): + if use_threads is not None: + self.use_threads = use_threads + if block_size is not None: + self.block_size = block_size + if skip_rows is not None: + self.skip_rows = skip_rows + if skip_rows_after_names is not None: + self.skip_rows_after_names = skip_rows_after_names + if column_names is not None: + self.column_names = column_names + if autogenerate_column_names is not None: + self.autogenerate_column_names= autogenerate_column_names + # Python-specific option + self.encoding = encoding + + @property + def use_threads(self): + """ + Whether to use multiple threads to accelerate reading. + """ + return deref(self.options).use_threads + + @use_threads.setter + def use_threads(self, value): + deref(self.options).use_threads = value + + @property + def block_size(self): + """ + How much bytes to process at a time from the input stream. + This will determine multi-threading granularity as well as + the size of individual record batches or table chunks. + """ + return deref(self.options).block_size + + @block_size.setter + def block_size(self, value): + deref(self.options).block_size = value + + @property + def skip_rows(self): + """ + The number of rows to skip before the column names (if any) + and the CSV data. 
+ See `skip_rows_after_names` for interaction description + """ + return deref(self.options).skip_rows + + @skip_rows.setter + def skip_rows(self, value): + deref(self.options).skip_rows = value + + @property + def skip_rows_after_names(self): + """ + The number of rows to skip after the column names. + This number can be larger than the number of rows in one + block, and empty rows are counted. + The order of application is as follows: + - `skip_rows` is applied (if non-zero); + - column names are read (unless `column_names` is set); + - `skip_rows_after_names` is applied (if non-zero). + """ + return deref(self.options).skip_rows_after_names + + @skip_rows_after_names.setter + def skip_rows_after_names(self, value): + deref(self.options).skip_rows_after_names = value + + @property + def column_names(self): + """ + The column names of the target table. If empty, fall back on + `autogenerate_column_names`. + """ + return [frombytes(s) for s in deref(self.options).column_names] + + @column_names.setter + def column_names(self, value): + deref(self.options).column_names.clear() + for item in value: + deref(self.options).column_names.push_back(tobytes(item)) + + @property + def autogenerate_column_names(self): + """ + Whether to autogenerate column names if `column_names` is empty. + If true, column names will be of the form "f0", "f1"... + If false, column names will be read from the first CSV row + after `skip_rows`. + """ + return deref(self.options).autogenerate_column_names + + @autogenerate_column_names.setter + def autogenerate_column_names(self, value): + deref(self.options).autogenerate_column_names = value + + def validate(self): + check_status(deref(self.options).Validate()) + + def equals(self, ReadOptions other): + """ + Parameters + ---------- + other : pyarrow.csv.ReadOptions + + Returns + ------- + bool + """ + return ( + self.use_threads == other.use_threads and + self.block_size == other.block_size and + self.skip_rows == other.skip_rows and + self.skip_rows_after_names == other.skip_rows_after_names and + self.column_names == other.column_names and + self.autogenerate_column_names == + other.autogenerate_column_names and + self.encoding == other.encoding + ) + + @staticmethod + cdef ReadOptions wrap(CCSVReadOptions options): + out = ReadOptions() + out.options.reset(new CCSVReadOptions(move(options))) + out.encoding = 'utf8' # No way to know this + return out + + def __getstate__(self): + return (self.use_threads, self.block_size, self.skip_rows, + self.column_names, self.autogenerate_column_names, + self.encoding, self.skip_rows_after_names) + + def __setstate__(self, state): + (self.use_threads, self.block_size, self.skip_rows, + self.column_names, self.autogenerate_column_names, + self.encoding, self.skip_rows_after_names) = state + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return False + + +cdef class ParseOptions(_Weakrefable): + """ + Options for parsing CSV files. + + Parameters + ---------- + delimiter : 1-character string, optional (default ',') + The character delimiting individual cells in the CSV data. + quote_char : 1-character string or False, optional (default '"') + The character used optionally for quoting CSV values + (False if quoting is not allowed). + double_quote : bool, optional (default True) + Whether two quotes in a quoted CSV value denote a single quote + in the data. 
+ escape_char : 1-character string or False, optional (default False) + The character used optionally for escaping special characters + (False if escaping is not allowed). + newlines_in_values : bool, optional (default False) + Whether newline characters are allowed in CSV values. + Setting this to True reduces the performance of multi-threaded + CSV reading. + ignore_empty_lines : bool, optional (default True) + Whether empty lines are ignored in CSV input. + If False, an empty line is interpreted as containing a single empty + value (assuming a one-column CSV file). + invalid_row_handler : callable, optional (default None) + If not None, this object is called for each CSV row that fails + parsing (because of a mismatching number of columns). + It should accept a single InvalidRow argument and return either + "skip" or "error" depending on the desired outcome. + + Examples + -------- + + Defining an example file from bytes object: + + >>> import io + >>> s = ( + ... "animals;n_legs;entry\\n" + ... "Flamingo;2;2022-03-01\\n" + ... "# Comment here:\\n" + ... "Horse;4;2022-03-02\\n" + ... "Brittle stars;5;2022-03-03\\n" + ... "Centipede;100;2022-03-04" + ... ) + >>> print(s) + animals;n_legs;entry + Flamingo;2;2022-03-01 + # Comment here: + Horse;4;2022-03-02 + Brittle stars;5;2022-03-03 + Centipede;100;2022-03-04 + >>> source = io.BytesIO(s.encode()) + + Read the data from a file skipping rows with comments + and defining the delimiter: + + >>> from pyarrow import csv + >>> def skip_comment(row): + ... if row.text.startswith("# "): + ... return 'skip' + ... else: + ... return 'error' + ... + >>> parse_options = csv.ParseOptions(delimiter=";", invalid_row_handler=skip_comment) + >>> csv.read_csv(source, parse_options=parse_options) + pyarrow.Table + animals: string + n_legs: int64 + entry: date32[day] + ---- + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + n_legs: [[2,4,5,100]] + entry: [[2022-03-01,2022-03-02,2022-03-03,2022-03-04]] + """ + __slots__ = () + + def __cinit__(self, *argw, **kwargs): + self._invalid_row_handler = None + self.options.reset(new CCSVParseOptions(CCSVParseOptions.Defaults())) + + def __init__(self, *, delimiter=None, quote_char=None, double_quote=None, + escape_char=None, newlines_in_values=None, + ignore_empty_lines=None, invalid_row_handler=None): + if delimiter is not None: + self.delimiter = delimiter + if quote_char is not None: + self.quote_char = quote_char + if double_quote is not None: + self.double_quote = double_quote + if escape_char is not None: + self.escape_char = escape_char + if newlines_in_values is not None: + self.newlines_in_values = newlines_in_values + if ignore_empty_lines is not None: + self.ignore_empty_lines = ignore_empty_lines + if invalid_row_handler is not None: + self.invalid_row_handler = invalid_row_handler + + @property + def delimiter(self): + """ + The character delimiting individual cells in the CSV data. + """ + return chr(deref(self.options).delimiter) + + @delimiter.setter + def delimiter(self, value): + deref(self.options).delimiter = _single_char(value) + + @property + def quote_char(self): + """ + The character used optionally for quoting CSV values + (False if quoting is not allowed). 
+ """ + if deref(self.options).quoting: + return chr(deref(self.options).quote_char) + else: + return False + + @quote_char.setter + def quote_char(self, value): + if value is False: + deref(self.options).quoting = False + else: + deref(self.options).quote_char = _single_char(value) + deref(self.options).quoting = True + + @property + def double_quote(self): + """ + Whether two quotes in a quoted CSV value denote a single quote + in the data. + """ + return deref(self.options).double_quote + + @double_quote.setter + def double_quote(self, value): + deref(self.options).double_quote = value + + @property + def escape_char(self): + """ + The character used optionally for escaping special characters + (False if escaping is not allowed). + """ + if deref(self.options).escaping: + return chr(deref(self.options).escape_char) + else: + return False + + @escape_char.setter + def escape_char(self, value): + if value is False: + deref(self.options).escaping = False + else: + deref(self.options).escape_char = _single_char(value) + deref(self.options).escaping = True + + @property + def newlines_in_values(self): + """ + Whether newline characters are allowed in CSV values. + Setting this to True reduces the performance of multi-threaded + CSV reading. + """ + return deref(self.options).newlines_in_values + + @newlines_in_values.setter + def newlines_in_values(self, value): + deref(self.options).newlines_in_values = value + + @property + def ignore_empty_lines(self): + """ + Whether empty lines are ignored in CSV input. + If False, an empty line is interpreted as containing a single empty + value (assuming a one-column CSV file). + """ + return deref(self.options).ignore_empty_lines + + @property + def invalid_row_handler(self): + """ + Optional handler for invalid rows. + + If not None, this object is called for each CSV row that fails + parsing (because of a mismatching number of columns). + It should accept a single InvalidRow argument and return either + "skip" or "error" depending on the desired outcome. 
+ """ + return self._invalid_row_handler + + @invalid_row_handler.setter + def invalid_row_handler(self, value): + if value is not None and not callable(value): + raise TypeError("Expected callable or None, " + f"got instance of {type(value)!r}") + self._invalid_row_handler = value + deref(self.options).invalid_row_handler = MakeInvalidRowHandler( + &_handle_invalid_row, value) + + @ignore_empty_lines.setter + def ignore_empty_lines(self, value): + deref(self.options).ignore_empty_lines = value + + def validate(self): + check_status(deref(self.options).Validate()) + + def equals(self, ParseOptions other): + """ + Parameters + ---------- + other : pyarrow.csv.ParseOptions + + Returns + ------- + bool + """ + return ( + self.delimiter == other.delimiter and + self.quote_char == other.quote_char and + self.double_quote == other.double_quote and + self.escape_char == other.escape_char and + self.newlines_in_values == other.newlines_in_values and + self.ignore_empty_lines == other.ignore_empty_lines and + self._invalid_row_handler == other._invalid_row_handler + ) + + @staticmethod + cdef ParseOptions wrap(CCSVParseOptions options): + out = ParseOptions() + out.options.reset(new CCSVParseOptions(move(options))) + return out + + def __getstate__(self): + return (self.delimiter, self.quote_char, self.double_quote, + self.escape_char, self.newlines_in_values, + self.ignore_empty_lines, self.invalid_row_handler) + + def __setstate__(self, state): + (self.delimiter, self.quote_char, self.double_quote, + self.escape_char, self.newlines_in_values, + self.ignore_empty_lines, self.invalid_row_handler) = state + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return False + + +cdef class _ISO8601(_Weakrefable): + """ + A special object indicating ISO-8601 parsing. + """ + __slots__ = () + + def __str__(self): + return 'ISO8601' + + def __eq__(self, other): + return isinstance(other, _ISO8601) + + +ISO8601 = _ISO8601() + + +cdef class ConvertOptions(_Weakrefable): + """ + Options for converting CSV data. + + Parameters + ---------- + check_utf8 : bool, optional (default True) + Whether to check UTF8 validity of string columns. + column_types : pyarrow.Schema or dict, optional + Explicitly map column names to column types. Passing this argument + disables type inference on the defined columns. + null_values : list, optional + A sequence of strings that denote nulls in the data + (defaults are appropriate in most cases). Note that by default, + string columns are not checked for null values. To enable + null checking for those, specify ``strings_can_be_null=True``. + true_values : list, optional + A sequence of strings that denote true booleans in the data + (defaults are appropriate in most cases). + false_values : list, optional + A sequence of strings that denote false booleans in the data + (defaults are appropriate in most cases). + decimal_point : 1-character string, optional (default '.') + The character used as decimal point in floating-point and decimal + data. + strings_can_be_null : bool, optional (default False) + Whether string / binary columns can have null values. + If true, then strings in null_values are considered null for + string columns. + If false, then all strings are valid string values. + quoted_strings_can_be_null : bool, optional (default True) + Whether quoted values can be null. + If true, then strings in "null_values" are also considered null + when they appear quoted in the CSV file. Otherwise, quoted values + are never considered null. 
+ include_columns : list, optional + The names of columns to include in the Table. + If empty, the Table will include all columns from the CSV file. + If not empty, only these columns will be included, in this order. + include_missing_columns : bool, optional (default False) + If false, columns in `include_columns` but not in the CSV file will + error out. + If true, columns in `include_columns` but not in the CSV file will + produce a column of nulls (whose type is selected using + `column_types`, or null by default). + This option is ignored if `include_columns` is empty. + auto_dict_encode : bool, optional (default False) + Whether to try to automatically dict-encode string / binary data. + If true, then when type inference detects a string or binary column, + it it dict-encoded up to `auto_dict_max_cardinality` distinct values + (per chunk), after which it switches to regular encoding. + This setting is ignored for non-inferred columns (those in + `column_types`). + auto_dict_max_cardinality : int, optional + The maximum dictionary cardinality for `auto_dict_encode`. + This value is per chunk. + timestamp_parsers : list, optional + A sequence of strptime()-compatible format strings, tried in order + when attempting to infer or convert timestamp values (the special + value ISO8601() can also be given). By default, a fast built-in + ISO-8601 parser is used. + + Examples + -------- + + Defining an example data: + + >>> import io + >>> s = ( + ... "animals,n_legs,entry,fast\\n" + ... "Flamingo,2,01/03/2022,Yes\\n" + ... "Horse,4,02/03/2022,Yes\\n" + ... "Brittle stars,5,03/03/2022,No\\n" + ... "Centipede,100,04/03/2022,No\\n" + ... ",6,05/03/2022," + ... ) + >>> print(s) + animals,n_legs,entry,fast + Flamingo,2,01/03/2022,Yes + Horse,4,02/03/2022,Yes + Brittle stars,5,03/03/2022,No + Centipede,100,04/03/2022,No + ,6,05/03/2022, + + Change the type of a column: + + >>> import pyarrow as pa + >>> from pyarrow import csv + >>> convert_options = csv.ConvertOptions(column_types={"n_legs": pa.float64()}) + >>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options) + pyarrow.Table + animals: string + n_legs: double + entry: string + fast: string + ---- + animals: [["Flamingo","Horse","Brittle stars","Centipede",""]] + n_legs: [[2,4,5,100,6]] + entry: [["01/03/2022","02/03/2022","03/03/2022","04/03/2022","05/03/2022"]] + fast: [["Yes","Yes","No","No",""]] + + Define a date parsing format to get a timestamp type column + (in case dates are not in ISO format and not converted by default): + + >>> convert_options = csv.ConvertOptions( + ... timestamp_parsers=["%m/%d/%Y", "%m-%d-%Y"]) + >>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options) + pyarrow.Table + animals: string + n_legs: int64 + entry: timestamp[s] + fast: string + ---- + animals: [["Flamingo","Horse","Brittle stars","Centipede",""]] + n_legs: [[2,4,5,100,6]] + entry: [[2022-01-03 00:00:00,2022-02-03 00:00:00,2022-03-03 00:00:00,2022-04-03 00:00:00,2022-05-03 00:00:00]] + fast: [["Yes","Yes","No","No",""]] + + Specify a subset of columns to be read: + + >>> convert_options = csv.ConvertOptions( + ... include_columns=["animals", "n_legs"]) + >>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options) + pyarrow.Table + animals: string + n_legs: int64 + ---- + animals: [["Flamingo","Horse","Brittle stars","Centipede",""]] + n_legs: [[2,4,5,100,6]] + + List additional column to be included as a null typed column: + + >>> convert_options = csv.ConvertOptions( + ... 
include_columns=["animals", "n_legs", "location"], + ... include_missing_columns=True) + >>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options) + pyarrow.Table + animals: string + n_legs: int64 + location: null + ---- + animals: [["Flamingo","Horse","Brittle stars","Centipede",""]] + n_legs: [[2,4,5,100,6]] + location: [5 nulls] + + Define columns as dictionary type (by default only the + string/binary columns are dictionary encoded): + + >>> convert_options = csv.ConvertOptions( + ... timestamp_parsers=["%m/%d/%Y", "%m-%d-%Y"], + ... auto_dict_encode=True) + >>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options) + pyarrow.Table + animals: dictionary + n_legs: int64 + entry: timestamp[s] + fast: dictionary + ---- + animals: [ -- dictionary: + ["Flamingo","Horse","Brittle stars","Centipede",""] -- indices: + [0,1,2,3,4]] + n_legs: [[2,4,5,100,6]] + entry: [[2022-01-03 00:00:00,2022-02-03 00:00:00,2022-03-03 00:00:00,2022-04-03 00:00:00,2022-05-03 00:00:00]] + fast: [ -- dictionary: + ["Yes","No",""] -- indices: + [0,0,1,1,2]] + + Set upper limit for the number of categories. If the categories + is more than the limit, the conversion to dictionary will not + happen: + + >>> convert_options = csv.ConvertOptions( + ... include_columns=["animals"], + ... auto_dict_encode=True, + ... auto_dict_max_cardinality=2) + >>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options) + pyarrow.Table + animals: string + ---- + animals: [["Flamingo","Horse","Brittle stars","Centipede",""]] + + Set empty strings to missing values: + + >>> convert_options = csv.ConvertOptions(include_columns=["animals", "n_legs"], + ... strings_can_be_null=True) + >>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options) + pyarrow.Table + animals: string + n_legs: int64 + ---- + animals: [["Flamingo","Horse","Brittle stars","Centipede",null]] + n_legs: [[2,4,5,100,6]] + + Define values to be True and False when converting a column + into a bool type: + + >>> convert_options = csv.ConvertOptions( + ... include_columns=["fast"], + ... false_values=["No"], + ... 
true_values=["Yes"]) + >>> csv.read_csv(io.BytesIO(s.encode()), convert_options=convert_options) + pyarrow.Table + fast: bool + ---- + fast: [[true,true,false,false,null]] + """ + + # Avoid mistakingly creating attributes + __slots__ = () + + def __cinit__(self, *argw, **kwargs): + self.options.reset( + new CCSVConvertOptions(CCSVConvertOptions.Defaults())) + + def __init__(self, *, check_utf8=None, column_types=None, null_values=None, + true_values=None, false_values=None, decimal_point=None, + strings_can_be_null=None, quoted_strings_can_be_null=None, + include_columns=None, include_missing_columns=None, + auto_dict_encode=None, auto_dict_max_cardinality=None, + timestamp_parsers=None): + if check_utf8 is not None: + self.check_utf8 = check_utf8 + if column_types is not None: + self.column_types = column_types + if null_values is not None: + self.null_values = null_values + if true_values is not None: + self.true_values = true_values + if false_values is not None: + self.false_values = false_values + if decimal_point is not None: + self.decimal_point = decimal_point + if strings_can_be_null is not None: + self.strings_can_be_null = strings_can_be_null + if quoted_strings_can_be_null is not None: + self.quoted_strings_can_be_null = quoted_strings_can_be_null + if include_columns is not None: + self.include_columns = include_columns + if include_missing_columns is not None: + self.include_missing_columns = include_missing_columns + if auto_dict_encode is not None: + self.auto_dict_encode = auto_dict_encode + if auto_dict_max_cardinality is not None: + self.auto_dict_max_cardinality = auto_dict_max_cardinality + if timestamp_parsers is not None: + self.timestamp_parsers = timestamp_parsers + + @property + def check_utf8(self): + """ + Whether to check UTF8 validity of string columns. + """ + return deref(self.options).check_utf8 + + @check_utf8.setter + def check_utf8(self, value): + deref(self.options).check_utf8 = value + + @property + def strings_can_be_null(self): + """ + Whether string / binary columns can have null values. + """ + return deref(self.options).strings_can_be_null + + @strings_can_be_null.setter + def strings_can_be_null(self, value): + deref(self.options).strings_can_be_null = value + + @property + def quoted_strings_can_be_null(self): + """ + Whether quoted values can be null. + """ + return deref(self.options).quoted_strings_can_be_null + + @quoted_strings_can_be_null.setter + def quoted_strings_can_be_null(self, value): + deref(self.options).quoted_strings_can_be_null = value + + @property + def column_types(self): + """ + Explicitly map column names to column types. + """ + d = {frombytes(item.first): pyarrow_wrap_data_type(item.second) + for item in deref(self.options).column_types} + return d + + @column_types.setter + def column_types(self, value): + cdef: + shared_ptr[CDataType] typ + + if isinstance(value, Mapping): + value = value.items() + + deref(self.options).column_types.clear() + for item in value: + if isinstance(item, Field): + k = item.name + v = item.type + else: + k, v = item + typ = pyarrow_unwrap_data_type(ensure_type(v)) + assert typ != NULL + deref(self.options).column_types[tobytes(k)] = typ + + @property + def null_values(self): + """ + A sequence of strings that denote nulls in the data. 
+ """ + return [frombytes(x) for x in deref(self.options).null_values] + + @null_values.setter + def null_values(self, value): + deref(self.options).null_values = [tobytes(x) for x in value] + + @property + def true_values(self): + """ + A sequence of strings that denote true booleans in the data. + """ + return [frombytes(x) for x in deref(self.options).true_values] + + @true_values.setter + def true_values(self, value): + deref(self.options).true_values = [tobytes(x) for x in value] + + @property + def false_values(self): + """ + A sequence of strings that denote false booleans in the data. + """ + return [frombytes(x) for x in deref(self.options).false_values] + + @false_values.setter + def false_values(self, value): + deref(self.options).false_values = [tobytes(x) for x in value] + + @property + def decimal_point(self): + """ + The character used as decimal point in floating-point and decimal + data. + """ + return chr(deref(self.options).decimal_point) + + @decimal_point.setter + def decimal_point(self, value): + deref(self.options).decimal_point = _single_char(value) + + @property + def auto_dict_encode(self): + """ + Whether to try to automatically dict-encode string / binary data. + """ + return deref(self.options).auto_dict_encode + + @auto_dict_encode.setter + def auto_dict_encode(self, value): + deref(self.options).auto_dict_encode = value + + @property + def auto_dict_max_cardinality(self): + """ + The maximum dictionary cardinality for `auto_dict_encode`. + + This value is per chunk. + """ + return deref(self.options).auto_dict_max_cardinality + + @auto_dict_max_cardinality.setter + def auto_dict_max_cardinality(self, value): + deref(self.options).auto_dict_max_cardinality = value + + @property + def include_columns(self): + """ + The names of columns to include in the Table. + + If empty, the Table will include all columns from the CSV file. + If not empty, only these columns will be included, in this order. + """ + return [frombytes(s) for s in deref(self.options).include_columns] + + @include_columns.setter + def include_columns(self, value): + deref(self.options).include_columns.clear() + for item in value: + deref(self.options).include_columns.push_back(tobytes(item)) + + @property + def include_missing_columns(self): + """ + If false, columns in `include_columns` but not in the CSV file will + error out. + If true, columns in `include_columns` but not in the CSV file will + produce a null column (whose type is selected using `column_types`, + or null by default). + This option is ignored if `include_columns` is empty. + """ + return deref(self.options).include_missing_columns + + @include_missing_columns.setter + def include_missing_columns(self, value): + deref(self.options).include_missing_columns = value + + @property + def timestamp_parsers(self): + """ + A sequence of strptime()-compatible format strings, tried in order + when attempting to infer or convert timestamp values (the special + value ISO8601() can also be given). By default, a fast built-in + ISO-8601 parser is used. 
+ """ + cdef: + shared_ptr[CTimestampParser] c_parser + c_string kind + + parsers = [] + for c_parser in deref(self.options).timestamp_parsers: + kind = deref(c_parser).kind() + if kind == b'strptime': + parsers.append(frombytes(deref(c_parser).format())) + else: + assert kind == b'iso8601' + parsers.append(ISO8601) + + return parsers + + @timestamp_parsers.setter + def timestamp_parsers(self, value): + cdef: + vector[shared_ptr[CTimestampParser]] c_parsers + + for v in value: + if isinstance(v, str): + c_parsers.push_back(CTimestampParser.MakeStrptime(tobytes(v))) + elif v == ISO8601: + c_parsers.push_back(CTimestampParser.MakeISO8601()) + else: + raise TypeError("Expected list of str or ISO8601 objects") + + deref(self.options).timestamp_parsers = move(c_parsers) + + @staticmethod + cdef ConvertOptions wrap(CCSVConvertOptions options): + out = ConvertOptions() + out.options.reset(new CCSVConvertOptions(move(options))) + return out + + def validate(self): + check_status(deref(self.options).Validate()) + + def equals(self, ConvertOptions other): + """ + Parameters + ---------- + other : pyarrow.csv.ConvertOptions + + Returns + ------- + bool + """ + return ( + self.check_utf8 == other.check_utf8 and + self.column_types == other.column_types and + self.null_values == other.null_values and + self.true_values == other.true_values and + self.false_values == other.false_values and + self.decimal_point == other.decimal_point and + self.timestamp_parsers == other.timestamp_parsers and + self.strings_can_be_null == other.strings_can_be_null and + self.quoted_strings_can_be_null == + other.quoted_strings_can_be_null and + self.auto_dict_encode == other.auto_dict_encode and + self.auto_dict_max_cardinality == + other.auto_dict_max_cardinality and + self.include_columns == other.include_columns and + self.include_missing_columns == other.include_missing_columns + ) + + def __getstate__(self): + return (self.check_utf8, self.column_types, self.null_values, + self.true_values, self.false_values, self.decimal_point, + self.timestamp_parsers, self.strings_can_be_null, + self.quoted_strings_can_be_null, self.auto_dict_encode, + self.auto_dict_max_cardinality, self.include_columns, + self.include_missing_columns) + + def __setstate__(self, state): + (self.check_utf8, self.column_types, self.null_values, + self.true_values, self.false_values, self.decimal_point, + self.timestamp_parsers, self.strings_can_be_null, + self.quoted_strings_can_be_null, self.auto_dict_encode, + self.auto_dict_max_cardinality, self.include_columns, + self.include_missing_columns) = state + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return False + + +cdef _get_reader(input_file, ReadOptions read_options, + shared_ptr[CInputStream]* out): + use_memory_map = False + get_input_stream(input_file, use_memory_map, out) + if read_options is not None: + out[0] = native_transcoding_input_stream(out[0], + read_options.encoding, + 'utf8') + + +cdef _get_read_options(ReadOptions read_options, CCSVReadOptions* out): + if read_options is None: + out[0] = CCSVReadOptions.Defaults() + else: + out[0] = deref(read_options.options) + + +cdef _get_parse_options(ParseOptions parse_options, CCSVParseOptions* out): + if parse_options is None: + out[0] = CCSVParseOptions.Defaults() + else: + out[0] = deref(parse_options.options) + + +cdef _get_convert_options(ConvertOptions convert_options, + CCSVConvertOptions* out): + if convert_options is None: + out[0] = CCSVConvertOptions.Defaults() + else: + out[0] = 
deref(convert_options.options) + + +cdef class CSVStreamingReader(RecordBatchReader): + """An object that reads record batches incrementally from a CSV file. + + Should not be instantiated directly by user code. + """ + cdef readonly: + Schema schema + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, " + "use pyarrow.csv.open_csv() instead." + .format(self.__class__.__name__)) + + # Note about cancellation: we cannot create a SignalStopHandler + # by default here, as several CSVStreamingReader instances may be + # created (including by the same thread). Handling cancellation + # would require having the user pass the SignalStopHandler. + # (in addition to solving ARROW-11853) + + cdef _open(self, shared_ptr[CInputStream] stream, + CCSVReadOptions c_read_options, + CCSVParseOptions c_parse_options, + CCSVConvertOptions c_convert_options, + MemoryPool memory_pool): + cdef: + shared_ptr[CSchema] c_schema + CIOContext io_context + + io_context = CIOContext(maybe_unbox_memory_pool(memory_pool)) + + with nogil: + self.reader = GetResultValue( + CCSVStreamingReader.Make( + io_context, stream, + move(c_read_options), move(c_parse_options), + move(c_convert_options))) + c_schema = self.reader.get().schema() + + self.schema = pyarrow_wrap_schema(c_schema) + + +def read_csv(input_file, read_options=None, parse_options=None, + convert_options=None, MemoryPool memory_pool=None): + """ + Read a Table from a stream of CSV data. + + Parameters + ---------- + input_file : string, path or file-like object + The location of CSV data. If a string or path, and if it ends + with a recognized compressed file extension (e.g. ".gz" or ".bz2"), + the data is automatically decompressed when reading. + read_options : pyarrow.csv.ReadOptions, optional + Options for the CSV reader (see pyarrow.csv.ReadOptions constructor + for defaults) + parse_options : pyarrow.csv.ParseOptions, optional + Options for the CSV parser + (see pyarrow.csv.ParseOptions constructor for defaults) + convert_options : pyarrow.csv.ConvertOptions, optional + Options for converting CSV data + (see pyarrow.csv.ConvertOptions constructor for defaults) + memory_pool : MemoryPool, optional + Pool to allocate Table memory from + + Returns + ------- + :class:`pyarrow.Table` + Contents of the CSV file as a in-memory table. + + Examples + -------- + + Defining an example file from bytes object: + + >>> import io + >>> s = ( + ... "animals,n_legs,entry\\n" + ... "Flamingo,2,2022-03-01\\n" + ... "Horse,4,2022-03-02\\n" + ... "Brittle stars,5,2022-03-03\\n" + ... "Centipede,100,2022-03-04" + ... 
) + >>> print(s) + animals,n_legs,entry + Flamingo,2,2022-03-01 + Horse,4,2022-03-02 + Brittle stars,5,2022-03-03 + Centipede,100,2022-03-04 + >>> source = io.BytesIO(s.encode()) + + Reading from the file + + >>> from pyarrow import csv + >>> csv.read_csv(source) + pyarrow.Table + animals: string + n_legs: int64 + entry: date32[day] + ---- + animals: [["Flamingo","Horse","Brittle stars","Centipede"]] + n_legs: [[2,4,5,100]] + entry: [[2022-03-01,2022-03-02,2022-03-03,2022-03-04]] + """ + cdef: + shared_ptr[CInputStream] stream + CCSVReadOptions c_read_options + CCSVParseOptions c_parse_options + CCSVConvertOptions c_convert_options + CIOContext io_context + SharedPtrNoGIL[CCSVReader] reader + shared_ptr[CTable] table + + _get_reader(input_file, read_options, &stream) + _get_read_options(read_options, &c_read_options) + _get_parse_options(parse_options, &c_parse_options) + _get_convert_options(convert_options, &c_convert_options) + + with SignalStopHandler() as stop_handler: + io_context = CIOContext( + maybe_unbox_memory_pool(memory_pool), + ( stop_handler.stop_token).stop_token) + reader = GetResultValue(CCSVReader.Make( + io_context, stream, + c_read_options, c_parse_options, c_convert_options)) + + with nogil: + table = GetResultValue(reader.get().Read()) + + return pyarrow_wrap_table(table) + + +def open_csv(input_file, read_options=None, parse_options=None, + convert_options=None, MemoryPool memory_pool=None): + """ + Open a streaming reader of CSV data. + + Reading using this function is always single-threaded. + + Parameters + ---------- + input_file : string, path or file-like object + The location of CSV data. If a string or path, and if it ends + with a recognized compressed file extension (e.g. ".gz" or ".bz2"), + the data is automatically decompressed when reading. 
+ read_options : pyarrow.csv.ReadOptions, optional + Options for the CSV reader (see pyarrow.csv.ReadOptions constructor + for defaults) + parse_options : pyarrow.csv.ParseOptions, optional + Options for the CSV parser + (see pyarrow.csv.ParseOptions constructor for defaults) + convert_options : pyarrow.csv.ConvertOptions, optional + Options for converting CSV data + (see pyarrow.csv.ConvertOptions constructor for defaults) + memory_pool : MemoryPool, optional + Pool to allocate Table memory from + + Returns + ------- + :class:`pyarrow.csv.CSVStreamingReader` + """ + cdef: + shared_ptr[CInputStream] stream + CCSVReadOptions c_read_options + CCSVParseOptions c_parse_options + CCSVConvertOptions c_convert_options + CSVStreamingReader reader + + _get_reader(input_file, read_options, &stream) + _get_read_options(read_options, &c_read_options) + _get_parse_options(parse_options, &c_parse_options) + _get_convert_options(convert_options, &c_convert_options) + + reader = CSVStreamingReader.__new__(CSVStreamingReader) + reader._open(stream, move(c_read_options), move(c_parse_options), + move(c_convert_options), memory_pool) + return reader + + +def _raise_invalid_function_option(value, description, *, + exception_class=ValueError): + raise exception_class(f"\"{value}\" is not a valid {description}") + + +cdef CQuotingStyle unwrap_quoting_style(quoting_style) except *: + if quoting_style == "needed": + return CQuotingStyle_Needed + elif quoting_style == "all_valid": + return CQuotingStyle_AllValid + elif quoting_style == "none": + return CQuotingStyle_None + _raise_invalid_function_option(quoting_style, "quoting style") + + +cdef wrap_quoting_style(quoting_style): + if quoting_style == CQuotingStyle_Needed: + return 'needed' + elif quoting_style == CQuotingStyle_AllValid: + return 'all_valid' + elif quoting_style == CQuotingStyle_None: + return 'none' + + +cdef class WriteOptions(_Weakrefable): + """ + Options for writing CSV files. + + Parameters + ---------- + include_header : bool, optional (default True) + Whether to write an initial header line with column names + batch_size : int, optional (default 1024) + How many rows to process together when converting and writing + CSV data + delimiter : 1-character string, optional (default ",") + The character delimiting individual cells in the CSV data. + quoting_style : str, optional (default "needed") + Whether to quote values, and if so, which quoting style to use. + The following values are accepted: + + - "needed" (default): only enclose values in quotes when needed. + - "all_valid": enclose all valid values in quotes; nulls are not quoted. + - "none": do not enclose any values in quotes; values containing + special characters (such as quotes, cell delimiters or line endings) + will raise an error. + """ + + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, *, include_header=None, batch_size=None, + delimiter=None, quoting_style=None): + self.options.reset(new CCSVWriteOptions(CCSVWriteOptions.Defaults())) + if include_header is not None: + self.include_header = include_header + if batch_size is not None: + self.batch_size = batch_size + if delimiter is not None: + self.delimiter = delimiter + if quoting_style is not None: + self.quoting_style = quoting_style + + @property + def include_header(self): + """ + Whether to write an initial header line with column names. 
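# A small sketch of the streaming path provided by open_csv() above: instead of
# materializing a whole Table, the returned CSVStreamingReader yields record
# batches incrementally. The in-memory buffer here stands in for a real file.
import io
from pyarrow import csv

buf = io.BytesIO(b"animals,n_legs\nFlamingo,2\nHorse,4\n")
reader = csv.open_csv(buf)
for batch in reader:
    print(batch.num_rows, batch.schema.names)  # batches share the inferred schema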
+ """ + return deref(self.options).include_header + + @include_header.setter + def include_header(self, value): + deref(self.options).include_header = value + + @property + def batch_size(self): + """ + How many rows to process together when converting and writing + CSV data. + """ + return deref(self.options).batch_size + + @batch_size.setter + def batch_size(self, value): + deref(self.options).batch_size = value + + @property + def delimiter(self): + """ + The character delimiting individual cells in the CSV data. + """ + return chr(deref(self.options).delimiter) + + @delimiter.setter + def delimiter(self, value): + deref(self.options).delimiter = _single_char(value) + + @property + def quoting_style(self): + """ + Whether to quote values, and if so, which quoting style to use. + The following values are accepted: + + - "needed" (default): only enclose values in quotes when needed. + - "all_valid": enclose all valid values in quotes; nulls are not quoted. + - "none": do not enclose any values in quotes; values containing + special characters (such as quotes, cell delimiters or line endings) + will raise an error. + """ + return wrap_quoting_style(deref(self.options).quoting_style) + + @quoting_style.setter + def quoting_style(self, value): + deref(self.options).quoting_style = unwrap_quoting_style(value) + + @staticmethod + cdef WriteOptions wrap(CCSVWriteOptions options): + out = WriteOptions() + out.options.reset(new CCSVWriteOptions(move(options))) + return out + + def validate(self): + check_status(self.options.get().Validate()) + + +cdef _get_write_options(WriteOptions write_options, CCSVWriteOptions* out): + if write_options is None: + out[0] = CCSVWriteOptions.Defaults() + else: + out[0] = deref(write_options.options) + + +def write_csv(data, output_file, write_options=None, + MemoryPool memory_pool=None): + """ + Write record batch or table to a CSV file. + + Parameters + ---------- + data : pyarrow.RecordBatch or pyarrow.Table + The data to write. + output_file : string, path, pyarrow.NativeFile, or file-like object + The location where to write the CSV data. + write_options : pyarrow.csv.WriteOptions + Options to configure writing the CSV data. + memory_pool : MemoryPool, optional + Pool for temporary allocations. + + Examples + -------- + + >>> import pyarrow as pa + >>> from pyarrow import csv + + >>> legs = pa.array([2, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"]) + >>> entry_date = pa.array(["01/03/2022", "02/03/2022", + ... "03/03/2022", "04/03/2022"]) + >>> table = pa.table([animals, legs, entry_date], + ... 
names=["animals", "n_legs", "entry"]) + + >>> csv.write_csv(table, "animals.csv") + + >>> write_options = csv.WriteOptions(include_header=False) + >>> csv.write_csv(table, "animals.csv", write_options=write_options) + + >>> write_options = csv.WriteOptions(delimiter=";") + >>> csv.write_csv(table, "animals.csv", write_options=write_options) + """ + cdef: + shared_ptr[COutputStream] stream + CCSVWriteOptions c_write_options + CMemoryPool* c_memory_pool + CRecordBatch* batch + CTable* table + _get_write_options(write_options, &c_write_options) + + get_writer(output_file, &stream) + c_memory_pool = maybe_unbox_memory_pool(memory_pool) + c_write_options.io_context = CIOContext(c_memory_pool) + if isinstance(data, RecordBatch): + batch = pyarrow_unwrap_batch(data).get() + with nogil: + check_status(WriteCSV(deref(batch), c_write_options, stream.get())) + elif isinstance(data, Table): + table = pyarrow_unwrap_table(data).get() + with nogil: + check_status(WriteCSV(deref(table), c_write_options, stream.get())) + else: + raise TypeError(f"Expected Table or RecordBatch, got '{type(data)}'") + + +cdef class CSVWriter(_CRecordBatchWriter): + """ + Writer to create a CSV file. + + Parameters + ---------- + sink : str, path, pyarrow.OutputStream or file-like object + The location where to write the CSV data. + schema : pyarrow.Schema + The schema of the data to be written. + write_options : pyarrow.csv.WriteOptions + Options to configure writing the CSV data. + memory_pool : MemoryPool, optional + Pool for temporary allocations. + """ + + def __init__(self, sink, Schema schema, *, + WriteOptions write_options=None, MemoryPool memory_pool=None): + cdef: + shared_ptr[COutputStream] c_stream + shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema) + CCSVWriteOptions c_write_options + CMemoryPool* c_memory_pool = maybe_unbox_memory_pool(memory_pool) + _get_write_options(write_options, &c_write_options) + c_write_options.io_context = CIOContext(c_memory_pool) + get_writer(sink, &c_stream) + with nogil: + self.writer = GetResultValue(MakeCSVWriter( + c_stream, c_schema, c_write_options)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset.pyx new file mode 100644 index 0000000000000000000000000000000000000000..fd50215cee9ae39b5d4d64582e9d50a3d9b366b9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset.pyx @@ -0,0 +1,4118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +"""Dataset is currently unstable. 
APIs subject to change without notice.""" + +from cython.operator cimport dereference as deref + +import codecs +import collections +from libcpp cimport bool + +import pyarrow as pa +from pyarrow.lib cimport * +from pyarrow.lib import ArrowTypeError, frombytes, tobytes, _pac +from pyarrow.includes.libarrow_dataset cimport * +from pyarrow._acero cimport ExecNodeOptions +from pyarrow._compute cimport Expression, _bind +from pyarrow._compute import _forbid_instantiation +from pyarrow._fs cimport FileSystem, FileSelector, FileInfo +from pyarrow._csv cimport ( + ConvertOptions, ParseOptions, ReadOptions, WriteOptions) +from pyarrow.util import _is_iterable, _is_path_like, _stringify_path +from pyarrow._json cimport ParseOptions as JsonParseOptions +from pyarrow._json cimport ReadOptions as JsonReadOptions + +try: + import pyarrow.substrait as pa_substrait +except ImportError: + pa_substrait = None + + +_DEFAULT_BATCH_SIZE = 2**17 +_DEFAULT_BATCH_READAHEAD = 16 +_DEFAULT_FRAGMENT_READAHEAD = 4 + + +# Initialise support for Datasets in ExecPlan +Initialize() + + +_orc_fileformat = None +_orc_imported = False + + +def _get_orc_fileformat(): + """ + Import OrcFileFormat on first usage (to avoid circular import issue + when `pyarrow._dataset_orc` would be imported first) + """ + global _orc_fileformat + global _orc_imported + if not _orc_imported: + try: + from pyarrow._dataset_orc import OrcFileFormat + _orc_fileformat = OrcFileFormat + except ImportError as e: + _orc_fileformat = None + finally: + _orc_imported = True + return _orc_fileformat + + +_dataset_pq = False + + +def _get_parquet_classes(): + """ + Import Parquet class files on first usage (to avoid circular import issue + when `pyarrow._dataset_parquet` would be imported first) + """ + global _dataset_pq + if _dataset_pq is False: + try: + import pyarrow._dataset_parquet as _dataset_pq + except ImportError: + _dataset_pq = None + + +def _get_parquet_symbol(name): + """ + Get a symbol from pyarrow.parquet if the latter is importable, otherwise + return None. 
+ """ + _get_parquet_classes() + return _dataset_pq and getattr(_dataset_pq, name) + + +cdef CFileSource _make_file_source(object file, FileSystem filesystem=None, object file_size=None): + + cdef: + CFileSource c_source + shared_ptr[CFileSystem] c_filesystem + CFileInfo c_info + c_string c_path + shared_ptr[CRandomAccessFile] c_file + shared_ptr[CBuffer] c_buffer + int64_t c_size + + if isinstance(file, Buffer): + c_buffer = pyarrow_unwrap_buffer(file) + c_source = CFileSource(move(c_buffer)) + elif _is_path_like(file): + if filesystem is None: + raise ValueError("cannot construct a FileSource from " + "a path without a FileSystem") + c_filesystem = filesystem.unwrap() + c_path = tobytes(_stringify_path(file)) + + if file_size is not None: + c_size = file_size + c_info = FileInfo(c_path, size=c_size).unwrap() + c_source = CFileSource(move(c_info), move(c_filesystem)) + else: + c_source = CFileSource(move(c_path), move(c_filesystem)) + elif hasattr(file, 'read'): + # Optimistically hope this is file-like + c_file = get_native_file(file, False).get_random_access_file() + c_source = CFileSource(move(c_file)) + + else: + raise TypeError("cannot construct a FileSource " + "from " + str(file)) + + return c_source + + +cdef CSegmentEncoding _get_segment_encoding(str segment_encoding): + if segment_encoding == "none": + return CSegmentEncoding_None + elif segment_encoding == "uri": + return CSegmentEncoding_Uri + raise ValueError(f"Unknown segment encoding: {segment_encoding}") + + +cdef str _wrap_segment_encoding(CSegmentEncoding segment_encoding): + if segment_encoding == CSegmentEncoding_None: + return "none" + elif segment_encoding == CSegmentEncoding_Uri: + return "uri" + raise ValueError("Unknown segment encoding") + + +cdef Expression _true = Expression._scalar(True) + + +cdef class Dataset(_Weakrefable): + """ + Collection of data fragments and potentially child datasets. + + Arrow Datasets allow you to query against data that has been split across + multiple files. This sharding of data may indicate partitioning, which + can accelerate queries that only touch some partitions (files). + """ + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef void init(self, const shared_ptr[CDataset]& sp): + self.wrapped = sp + self.dataset = sp.get() + self._scan_options = dict() + + @staticmethod + cdef wrap(const shared_ptr[CDataset]& sp): + type_name = frombytes(sp.get().type_name()) + + classes = { + 'union': UnionDataset, + 'filesystem': FileSystemDataset, + 'in-memory': InMemoryDataset, + } + + class_ = classes.get(type_name, None) + if class_ is None: + raise TypeError(type_name) + + cdef Dataset self = class_.__new__(class_) + self.init(sp) + return self + + cdef shared_ptr[CDataset] unwrap(self) nogil: + return self.wrapped + + @property + def partition_expression(self): + """ + An Expression which evaluates to true for all data viewed by this + Dataset. + """ + return Expression.wrap(self.dataset.partition_expression()) + + def replace_schema(self, Schema schema not None): + """ + Return a copy of this Dataset with a different schema. + + The copy will view the same Fragments. If the new schema is not + compatible with the original dataset's schema then an error will + be raised. + + Parameters + ---------- + schema : Schema + The new dataset schema. + """ + cdef shared_ptr[CDataset] copy = GetResultValue( + self.dataset.ReplaceSchema(pyarrow_unwrap_schema(schema)) + ) + + d = Dataset.wrap(move(copy)) + if self._scan_options: + # Preserve scan options if set. 
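+            # Only the "filter" key is stored in _scan_options in this file
+            # (set via Dataset.filter()); copying the dict keeps any such
+            # filter active on the schema-replaced copy without sharing
+            # mutable state with the original dataset.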
+ d._scan_options = self._scan_options.copy() + return d + + def get_fragments(self, Expression filter=None): + """Returns an iterator over the fragments in this dataset. + + Parameters + ---------- + filter : Expression, default None + Return fragments matching the optional filter, either using the + partition_expression or internal information like Parquet's + statistics. + + Returns + ------- + fragments : iterator of Fragment + """ + if self._scan_options.get("filter") is not None: + # Accessing fragments of a filtered dataset is not supported. + # It would be unclear if you wanted to filter the fragments + # or the rows in those fragments. + raise ValueError( + "Retrieving fragments of a filtered or projected " + "dataset is not allowed. Remove the filtering." + ) + + return self._get_fragments(filter) + + def _get_fragments(self, Expression filter): + cdef: + CExpression c_filter + + if filter is None: + c_fragments = move(GetResultValue(self.dataset.GetFragments())) + else: + c_filter = _bind(filter, self.schema) + c_fragments = move(GetResultValue( + self.dataset.GetFragments(c_filter))) + + for maybe_fragment in c_fragments: + yield Fragment.wrap(GetResultValue(move(maybe_fragment))) + + def _scanner_options(self, options): + """Returns the default options to create a new Scanner. + + This is automatically invoked by :meth:`Dataset.scanner` + and there is no need to use it. + """ + new_options = options.copy() + + # at the moment only support filter + requested_filter = options.get("filter") + if pa_substrait and isinstance(requested_filter, pa_substrait.BoundExpressions): + expressions = list(requested_filter.expressions.values()) + if len(expressions) != 1: + raise ValueError( + "Only one BoundExpressions with a single expression are supported") + new_options["filter"] = requested_filter = expressions[0] + + current_filter = self._scan_options.get("filter") + if requested_filter is not None and current_filter is not None: + new_options["filter"] = current_filter & requested_filter + elif current_filter is not None: + new_options["filter"] = current_filter + + return new_options + + def scanner(self, + object columns=None, + object filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, + MemoryPool memory_pool=None): + """ + Build a scan operation against the dataset. + + Data is not loaded immediately. Instead, this produces a Scanner, + which exposes further operations (e.g. loading all data as a + table, counting rows). + + See the :meth:`Scanner.from_dataset` method for further information. + + Parameters + ---------- + columns : list of str, default None + The columns to project. This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. + + The list of columns or expressions may use the special fields + `__batch_index` (the index of the batch within the fragment), + `__fragment_index` (the index of the fragment within the dataset), + `__last_in_fragment` (whether the batch is last in fragment), and + `__filename` (the name of the source file or a description of the + source fragment). + + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. 
+ By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + + Returns + ------- + scanner : Scanner + + Examples + -------- + >>> import pyarrow as pa + >>> table = pa.table({'year': [2020, 2022, 2021, 2022, 2019, 2021], + ... 'n_legs': [2, 2, 4, 4, 5, 100], + ... 'animal': ["Flamingo", "Parrot", "Dog", "Horse", + ... "Brittle stars", "Centipede"]}) + >>> + >>> import pyarrow.parquet as pq + >>> pq.write_table(table, "dataset_scanner.parquet") + + >>> import pyarrow.dataset as ds + >>> dataset = ds.dataset("dataset_scanner.parquet") + + Selecting a subset of the columns: + + >>> dataset.scanner(columns=["year", "n_legs"]).to_table() + pyarrow.Table + year: int64 + n_legs: int64 + ---- + year: [[2020,2022,2021,2022,2019,2021]] + n_legs: [[2,2,4,4,5,100]] + + Projecting selected columns using an expression: + + >>> dataset.scanner(columns={ + ... "n_legs_uint": ds.field("n_legs").cast("uint8"), + ... }).to_table() + pyarrow.Table + n_legs_uint: uint8 + ---- + n_legs_uint: [[2,2,4,4,5,100]] + + Filtering rows while scanning: + + >>> dataset.scanner(filter=ds.field("year") > 2020).to_table() + pyarrow.Table + year: int64 + n_legs: int64 + animal: string + ---- + year: [[2022,2021,2022,2021]] + n_legs: [[2,4,4,100]] + animal: [["Parrot","Dog","Horse","Centipede"]] + """ + return Scanner.from_dataset( + self, + columns=columns, + filter=filter, + batch_size=batch_size, + batch_readahead=batch_readahead, + fragment_readahead=fragment_readahead, + fragment_scan_options=fragment_scan_options, + use_threads=use_threads, + memory_pool=memory_pool + ) + + def to_batches(self, + object columns=None, + Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, + MemoryPool memory_pool=None): + """ + Read the dataset as materialized record batches. + + Parameters + ---------- + columns : list of str, default None + The columns to project. 
This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. + + The list of columns or expressions may use the special fields + `__batch_index` (the index of the batch within the fragment), + `__fragment_index` (the index of the fragment within the dataset), + `__last_in_fragment` (whether the batch is last in fragment), and + `__filename` (the name of the source file or a description of the + source fragment). + + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + + Returns + ------- + record_batches : iterator of RecordBatch + """ + return self.scanner( + columns=columns, + filter=filter, + batch_size=batch_size, + batch_readahead=batch_readahead, + fragment_readahead=fragment_readahead, + fragment_scan_options=fragment_scan_options, + use_threads=use_threads, + memory_pool=memory_pool + ).to_batches() + + def to_table(self, + object columns=None, + Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, + MemoryPool memory_pool=None): + """ + Read the dataset to an Arrow table. + + Note that this method reads all the selected data from the dataset + into memory. + + Parameters + ---------- + columns : list of str, default None + The columns to project. This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. 
+ + The list of columns or expressions may use the special fields + `__batch_index` (the index of the batch within the fragment), + `__fragment_index` (the index of the fragment within the dataset), + `__last_in_fragment` (whether the batch is last in fragment), and + `__filename` (the name of the source file or a description of the + source fragment). + + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + + Returns + ------- + table : Table + """ + return self.scanner( + columns=columns, + filter=filter, + batch_size=batch_size, + batch_readahead=batch_readahead, + fragment_readahead=fragment_readahead, + fragment_scan_options=fragment_scan_options, + use_threads=use_threads, + memory_pool=memory_pool + ).to_table() + + def take(self, + object indices, + object columns=None, + Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, + MemoryPool memory_pool=None): + """ + Select rows of data by index. + + Parameters + ---------- + indices : Array or array-like + indices of rows to select in the dataset. + columns : list of str, default None + The columns to project. This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. + + The list of columns or expressions may use the special fields + `__batch_index` (the index of the batch within the fragment), + `__fragment_index` (the index of the fragment within the dataset), + `__last_in_fragment` (whether the batch is last in fragment), and + `__filename` (the name of the source file or a description of the + source fragment). 
+ + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + + Returns + ------- + table : Table + """ + return self.scanner( + columns=columns, + filter=filter, + batch_size=batch_size, + batch_readahead=batch_readahead, + fragment_readahead=fragment_readahead, + fragment_scan_options=fragment_scan_options, + use_threads=use_threads, + memory_pool=memory_pool + ).take(indices) + + def head(self, + int num_rows, + object columns=None, + Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, + MemoryPool memory_pool=None): + """ + Load the first N rows of the dataset. + + Parameters + ---------- + num_rows : int + The number of rows to load. + columns : list of str, default None + The columns to project. This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. + + The list of columns or expressions may use the special fields + `__batch_index` (the index of the batch within the fragment), + `__fragment_index` (the index of the fragment within the dataset), + `__last_in_fragment` (whether the batch is last in fragment), and + `__filename` (the name of the source file or a description of the + source fragment). + + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. 
+ If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + + Returns + ------- + table : Table + """ + return self.scanner( + columns=columns, + filter=filter, + batch_size=batch_size, + batch_readahead=batch_readahead, + fragment_readahead=fragment_readahead, + fragment_scan_options=fragment_scan_options, + use_threads=use_threads, + memory_pool=memory_pool + ).head(num_rows) + + def count_rows(self, + Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, + MemoryPool memory_pool=None): + """ + Count rows matching the scanner filter. + + Parameters + ---------- + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. 
+
+        Returns
+        -------
+        count : int
+        """
+        return self.scanner(
+            filter=filter,
+            batch_size=batch_size,
+            batch_readahead=batch_readahead,
+            fragment_readahead=fragment_readahead,
+            fragment_scan_options=fragment_scan_options,
+            use_threads=use_threads,
+            memory_pool=memory_pool
+        ).count_rows()
+
+    @property
+    def schema(self):
+        """The common schema of the full Dataset"""
+        return pyarrow_wrap_schema(self.dataset.schema())
+
+    def filter(self, expression not None):
+        """
+        Apply a row filter to the dataset.
+
+        Parameters
+        ----------
+        expression : Expression
+            The filter that should be applied to the dataset.
+
+        Returns
+        -------
+        Dataset
+        """
+        cdef:
+            Dataset filtered_dataset
+
+        new_filter = expression
+        current_filter = self._scan_options.get("filter")
+        if current_filter is not None and new_filter is not None:
+            new_filter = current_filter & new_filter
+
+        filtered_dataset = self.__class__.__new__(self.__class__)
+        filtered_dataset.init(self.wrapped)
+        filtered_dataset._scan_options = dict(filter=new_filter)
+        return filtered_dataset
+
+    def sort_by(self, sorting, **kwargs):
+        """
+        Sort the Dataset by one or multiple columns.
+
+        Parameters
+        ----------
+        sorting : str or list[tuple(name, order)]
+            Name of the column to use to sort (ascending), or
+            a list of multiple sorting conditions where
+            each entry is a tuple with column name
+            and sorting order ("ascending" or "descending")
+        **kwargs : dict, optional
+            Additional sorting options.
+            As allowed by :class:`SortOptions`
+
+        Returns
+        -------
+        InMemoryDataset
+            A new dataset sorted according to the sort keys.
+        """
+        if isinstance(sorting, str):
+            sorting = [(sorting, "ascending")]
+
+        res = _pac()._sort_source(
+            self, output_type=InMemoryDataset, sort_keys=sorting, **kwargs
+        )
+        return res
+
+    def join(self, right_dataset, keys, right_keys=None, join_type="left outer",
+             left_suffix=None, right_suffix=None, coalesce_keys=True,
+             use_threads=True):
+        """
+        Perform a join between this dataset and another one.
+
+        Result of the join will be a new dataset, where further
+        operations can be applied.
+
+        Parameters
+        ----------
+        right_dataset : dataset
+            The dataset to join to the current one, acting as the right dataset
+            in the join operation.
+        keys : str or list[str]
+            The columns from current dataset that should be used as keys
+            of the join operation left side.
+        right_keys : str or list[str], default None
+            The columns from the right_dataset that should be used as keys
+            on the join operation right side.
+            When ``None`` use the same key names as the left dataset.
+        join_type : str, default "left outer"
+            The kind of join that should be performed, one of
+            ("left semi", "right semi", "left anti", "right anti",
+            "inner", "left outer", "right outer", "full outer")
+        left_suffix : str, default None
+            Which suffix to add to left column names. This prevents confusion
+            when the columns in left and right datasets have colliding names.
+        right_suffix : str, default None
+            Which suffix to add to the right column names. This prevents confusion
+            when the columns in left and right datasets have colliding names.
+        coalesce_keys : bool, default True
+            If the duplicated keys should be omitted from one of the sides
+            in the join result.
+        use_threads : bool, default True
+            Whether to use multithreading or not.
+ + Returns + ------- + InMemoryDataset + """ + if right_keys is None: + right_keys = keys + return _pac()._perform_join( + join_type, self, keys, right_dataset, right_keys, + left_suffix=left_suffix, right_suffix=right_suffix, + use_threads=use_threads, coalesce_keys=coalesce_keys, + output_type=InMemoryDataset + ) + + def join_asof(self, right_dataset, on, by, tolerance, right_on=None, right_by=None): + """ + Perform an asof join between this dataset and another one. + + This is similar to a left-join except that we match on nearest key rather + than equal keys. Both datasets must be sorted by the key. This type of join + is most useful for time series data that are not perfectly aligned. + + Optionally match on equivalent keys with "by" before searching with "on". + + Result of the join will be a new Dataset, where further + operations can be applied. + + Parameters + ---------- + right_dataset : dataset + The dataset to join to the current one, acting as the right dataset + in the join operation. + on : str + The column from current dataset that should be used as the "on" key + of the join operation left side. + + An inexact match is used on the "on" key, i.e. a row is considered a + match if and only if left_on - tolerance <= right_on <= left_on. + + The input table must be sorted by the "on" key. Must be a single + field of a common type. + + Currently, the "on" key must be an integer, date, or timestamp type. + by : str or list[str] + The columns from current dataset that should be used as the keys + of the join operation left side. The join operation is then done + only for the matches in these columns. + tolerance : int + The tolerance for inexact "on" key matching. A right row is considered + a match with the left row `right.on - left.on <= tolerance`. The + `tolerance` may be: + + - negative, in which case a past-as-of-join occurs; + - or positive, in which case a future-as-of-join occurs; + - or zero, in which case an exact-as-of-join occurs. + + The tolerance is interpreted in the same units as the "on" key. + right_on : str or list[str], default None + The columns from the right_dataset that should be used as the on key + on the join operation right side. + When ``None`` use the same key name as the left dataset. + right_by : str or list[str], default None + The columns from the right_dataset that should be used as by keys + on the join operation right side. + When ``None`` use the same key names as the left dataset. + + Returns + ------- + InMemoryDataset + """ + if right_on is None: + right_on = on + if right_by is None: + right_by = by + return _pac()._perform_join_asof(self, on, by, + right_dataset, right_on, right_by, + tolerance, output_type=InMemoryDataset) + + +cdef class InMemoryDataset(Dataset): + """ + A Dataset wrapping in-memory data. + + Parameters + ---------- + source : RecordBatch, Table, list, tuple + The data for this dataset. Can be a RecordBatch, Table, list of + RecordBatch/Table, iterable of RecordBatch, or a RecordBatchReader + If an iterable is provided, the schema must also be provided. 
+    schema : Schema, optional
+        Only required if passing an iterable as the source
+    """
+
+    cdef:
+        CInMemoryDataset* in_memory_dataset
+
+    def __init__(self, source, Schema schema=None):
+        cdef:
+            shared_ptr[CInMemoryDataset] in_memory_dataset
+
+        if isinstance(source, (pa.RecordBatch, pa.Table)):
+            source = [source]
+
+        if isinstance(source, (list, tuple)):
+            batches = []
+            for item in source:
+                if isinstance(item, pa.RecordBatch):
+                    batches.append(item)
+                elif isinstance(item, pa.Table):
+                    batches.extend(item.to_batches())
+                else:
+                    raise TypeError(
+                        'Expected a list of tables or batches. The given list '
+                        'contains a ' + type(item).__name__)
+                if schema is None:
+                    schema = item.schema
+                elif not schema.equals(item.schema):
+                    raise ArrowTypeError(
+                        f'Item has schema\n{item.schema}\nwhich does not '
+                        f'match expected schema\n{schema}')
+            if not batches and schema is None:
+                raise ValueError('Must provide schema to construct in-memory '
+                                 'dataset from an empty list')
+            table = pa.Table.from_batches(batches, schema=schema)
+            in_memory_dataset = make_shared[CInMemoryDataset](
+                pyarrow_unwrap_table(table))
+        else:
+            raise TypeError(
+                'Expected a table, batch, or list of tables/batches '
+                'instead of the given type: ' +
+                type(source).__name__
+            )
+
+        self.init(<shared_ptr[CDataset]> in_memory_dataset)
+
+    cdef void init(self, const shared_ptr[CDataset]& sp):
+        Dataset.init(self, sp)
+        self.in_memory_dataset = <CInMemoryDataset*> sp.get()
+
+
+cdef class UnionDataset(Dataset):
+    """
+    A Dataset wrapping child datasets.
+
+    Children's schemas must agree with the provided schema.
+
+    Parameters
+    ----------
+    schema : Schema
+        A known schema to conform to.
+    children : list of Dataset
+        One or more input children
+    """
+
+    cdef:
+        CUnionDataset* union_dataset
+
+    def __init__(self, Schema schema not None, children):
+        cdef:
+            Dataset child
+            CDatasetVector c_children
+            shared_ptr[CUnionDataset] union_dataset
+
+        for child in children:
+            c_children.push_back(child.wrapped)
+
+        union_dataset = GetResultValue(CUnionDataset.Make(
+            pyarrow_unwrap_schema(schema), move(c_children)))
+        self.init(<shared_ptr[CDataset]> union_dataset)
+
+    cdef void init(self, const shared_ptr[CDataset]& sp):
+        Dataset.init(self, sp)
+        self.union_dataset = <CUnionDataset*> sp.get()
+
+    def __reduce__(self):
+        return UnionDataset, (self.schema, self.children)
+
+    @property
+    def children(self):
+        cdef CDatasetVector children = self.union_dataset.children()
+        return [Dataset.wrap(children[i]) for i in range(children.size())]
+
+
+cdef class FileSystemDataset(Dataset):
+    """
+    A Dataset of file fragments.
+
+    A FileSystemDataset is composed of one or more FileFragment.
+
+    Parameters
+    ----------
+    fragments : list[Fragment]
+        List of fragments to consume.
+    schema : Schema
+        The top-level schema of the Dataset.
+    format : FileFormat
+        File format of the fragments, currently only ParquetFileFormat,
+        IpcFileFormat, CsvFileFormat, and JsonFileFormat are supported.
+    filesystem : FileSystem
+        FileSystem of the fragments.
+    root_partition : Expression, optional
+        The top-level partition of the Dataset.
+    """
+
+    cdef:
+        CFileSystemDataset* filesystem_dataset
+
+    def __init__(self, fragments, Schema schema, FileFormat format,
+                 FileSystem filesystem=None, root_partition=None):
+        cdef:
+            FileFragment fragment=None
+            vector[shared_ptr[CFileFragment]] c_fragments
+            CResult[shared_ptr[CDataset]] result
+            shared_ptr[CFileSystem] c_filesystem
+
+        if root_partition is None:
+            root_partition = _true
+        elif not isinstance(root_partition, Expression):
+            raise TypeError(
+                "Argument 'root_partition' has incorrect type (expected "
+                "Expression, got {0})".format(type(root_partition))
+            )
+
+        for fragment in fragments:
+            c_fragments.push_back(
+                static_pointer_cast[CFileFragment, CFragment](
+                    fragment.unwrap()))
+
+            if filesystem is None:
+                filesystem = fragment.filesystem
+
+        if filesystem is not None:
+            c_filesystem = filesystem.unwrap()
+
+        result = CFileSystemDataset.Make(
+            pyarrow_unwrap_schema(schema),
+            (<Expression> root_partition).unwrap(),
+            format.unwrap(),
+            c_filesystem,
+            c_fragments
+        )
+        self.init(GetResultValue(result))
+
+    @property
+    def filesystem(self):
+        return FileSystem.wrap(self.filesystem_dataset.filesystem())
+
+    @property
+    def partitioning(self):
+        """
+        The partitioning of the Dataset source, if discovered.
+
+        If the FileSystemDataset is created using the ``dataset()`` factory
+        function with a partitioning specified, this will return the
+        finalized Partitioning object from the dataset discovery. In all
+        other cases, this returns None.
+        """
+        c_partitioning = self.filesystem_dataset.partitioning()
+        if c_partitioning.get() == nullptr:
+            return None
+        try:
+            return Partitioning.wrap(c_partitioning)
+        except TypeError:
+            # e.g. type_name "default"
+            return None
+
+    cdef void init(self, const shared_ptr[CDataset]& sp):
+        Dataset.init(self, sp)
+        self.filesystem_dataset = <CFileSystemDataset*> sp.get()
+
+    def __reduce__(self):
+        return FileSystemDataset, (
+            list(self.get_fragments()),
+            self.schema,
+            self.format,
+            self.filesystem,
+            self.partition_expression
+        )
+
+    @classmethod
+    def from_paths(cls, paths, schema=None, format=None,
+                   filesystem=None, partitions=None, root_partition=None):
+        """
+        A Dataset created from a list of paths on a particular filesystem.
+
+        Parameters
+        ----------
+        paths : list of str
+            List of file paths to create the fragments from.
+        schema : Schema
+            The top-level schema of the Dataset.
+        format : FileFormat
+            File format to create fragments from, currently only
+            ParquetFileFormat, IpcFileFormat, CsvFileFormat, and JsonFileFormat are supported.
+        filesystem : FileSystem
+            The filesystem which files are from.
+        partitions : list[Expression], optional
+            Attach additional partition information for the file paths.
+        root_partition : Expression, optional
+            The top-level partition of the Dataset.
+        """
+        if root_partition is None:
+            root_partition = _true
+
+        for arg, class_, name in [
+            (schema, Schema, 'schema'),
+            (format, FileFormat, 'format'),
+            (filesystem, FileSystem, 'filesystem'),
+            (root_partition, Expression, 'root_partition')
+        ]:
+            if not isinstance(arg, class_):
+                raise TypeError(
+                    "Argument '{0}' has incorrect type (expected {1}, "
+                    "got {2})".format(name, class_.__name__, type(arg))
+                )
+
+        partitions = partitions or [_true] * len(paths)
+
+        if len(paths) != len(partitions):
+            raise ValueError(
+                'The number of files resulting from paths_or_selector '
+                'must be equal to the number of partitions.'
+ ) + + fragments = [ + format.make_fragment(path, filesystem, partitions[i]) + for i, path in enumerate(paths) + ] + return FileSystemDataset(fragments, schema, format, + filesystem, root_partition) + + @property + def files(self): + """List of the files""" + cdef vector[c_string] files = self.filesystem_dataset.files() + return [frombytes(f) for f in files] + + @property + def format(self): + """The FileFormat of this source.""" + return FileFormat.wrap(self.filesystem_dataset.format()) + + +cdef class FileWriteOptions(_Weakrefable): + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef void init(self, const shared_ptr[CFileWriteOptions]& sp): + self.wrapped = sp + self.c_options = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CFileWriteOptions]& sp): + type_name = frombytes(sp.get().type_name()) + + classes = { + 'csv': CsvFileWriteOptions, + 'ipc': IpcFileWriteOptions, + 'parquet': _get_parquet_symbol('ParquetFileWriteOptions'), + } + + class_ = classes.get(type_name, None) + if class_ is None: + raise TypeError(type_name) + + cdef FileWriteOptions self = class_.__new__(class_) + self.init(sp) + return self + + @property + def format(self): + return FileFormat.wrap(self.c_options.format()) + + cdef inline shared_ptr[CFileWriteOptions] unwrap(self): + return self.wrapped + + +cdef class FileFormat(_Weakrefable): + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef void init(self, const shared_ptr[CFileFormat]& sp): + self.wrapped = sp + self.format = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CFileFormat]& sp): + type_name = frombytes(sp.get().type_name()) + + classes = { + 'ipc': IpcFileFormat, + 'csv': CsvFileFormat, + 'json': JsonFileFormat, + 'parquet': _get_parquet_symbol('ParquetFileFormat'), + 'orc': _get_orc_fileformat(), + } + + class_ = classes.get(type_name, None) + if class_ is None: + raise TypeError(type_name) + + cdef FileFormat self = class_.__new__(class_) + self.init(sp) + return self + + cdef WrittenFile _finish_write(self, path, base_dir, + CFileWriter* file_writer): + parquet_metadata = None + size = GetResultValue(file_writer.GetBytesWritten()) + return WrittenFile(path, parquet_metadata, size) + + cdef inline shared_ptr[CFileFormat] unwrap(self): + return self.wrapped + + def inspect(self, file, filesystem=None): + """ + Infer the schema of a file. + + Parameters + ---------- + file : file-like object, path-like or str + The file or file path to infer a schema from. + filesystem : Filesystem, optional + If `filesystem` is given, `file` must be a string and specifies + the path of the file to read from the filesystem. + + Returns + ------- + schema : Schema + The schema inferred from the file + """ + cdef: + CFileSource c_source = _make_file_source(file, filesystem, file_size=None) + CResult[shared_ptr[CSchema]] c_result + with nogil: + c_result = self.format.Inspect(c_source) + c_schema = GetResultValue(c_result) + return pyarrow_wrap_schema(move(c_schema)) + + def make_fragment(self, file, filesystem=None, + Expression partition_expression=None, + *, file_size=None): + """ + Make a FileFragment from a given file. + + Parameters + ---------- + file : file-like object, path-like or str + The file or file path to make a fragment from. + filesystem : Filesystem, optional + If `filesystem` is given, `file` must be a string and specifies + the path of the file to read from the filesystem. + partition_expression : Expression, optional + An expression that is guaranteed true for all rows in the fragment. 
Allows
+            fragment to be potentially skipped while scanning with a filter.
+        file_size : int, optional
+            The size of the file in bytes. Can improve performance with high-latency filesystems
+            when file size needs to be known before reading.
+
+        Returns
+        -------
+        fragment : Fragment
+            The file fragment
+        """
+        if partition_expression is None:
+            partition_expression = _true
+        c_source = _make_file_source(file, filesystem, file_size)
+        c_fragment = GetResultValue(
+            self.format.MakeFragment(move(c_source),
+                                     partition_expression.unwrap(),
+                                     nullptr))
+        return Fragment.wrap(move(c_fragment))
+
+    def make_write_options(self):
+        sp_write_options = self.format.DefaultWriteOptions()
+        if sp_write_options.get() == nullptr:
+            # DefaultWriteOptions() may return `nullptr` which means that
+            # the format does not yet support writing datasets.
+            raise NotImplementedError(
+                "Writing datasets not yet implemented for this file format."
+            )
+        return FileWriteOptions.wrap(sp_write_options)
+
+    @property
+    def default_extname(self):
+        return frombytes(self.format.type_name())
+
+    @property
+    def default_fragment_scan_options(self):
+        dfso = FragmentScanOptions.wrap(
+            self.wrapped.get().default_fragment_scan_options)
+        # CsvFileFormat stores a Python-specific encoding field that needs
+        # to be restored because it does not exist in the C++ struct
+        if isinstance(self, CsvFileFormat):
+            if self._read_options_py is not None:
+                dfso.read_options = self._read_options_py
+        return dfso
+
+    @default_fragment_scan_options.setter
+    def default_fragment_scan_options(self, FragmentScanOptions options):
+        if options is None:
+            self.wrapped.get().default_fragment_scan_options =\
+                <shared_ptr[CFragmentScanOptions]>nullptr
+        else:
+            self._set_default_fragment_scan_options(options)
+
+    cdef _set_default_fragment_scan_options(self, FragmentScanOptions options):
+        raise ValueError(f"Cannot set fragment scan options for "
+                         f"'{options.type_name}' on {self.__class__.__name__}")
+
+    def __eq__(self, other):
+        try:
+            return self.equals(other)
+        except TypeError:
+            return False
+
+
+cdef class Fragment(_Weakrefable):
+    """Fragment of data from a Dataset."""
+
+    def __init__(self):
+        _forbid_instantiation(self.__class__)
+
+    cdef void init(self, const shared_ptr[CFragment]& sp):
+        self.wrapped = sp
+        self.fragment = sp.get()
+
+    @staticmethod
+    cdef wrap(const shared_ptr[CFragment]& sp):
+        type_name = frombytes(sp.get().type_name())
+
+        classes = {
+            # IpcFileFormat, CsvFileFormat, JsonFileFormat and OrcFileFormat do not have
+            # corresponding subclasses of FileFragment
+            'ipc': FileFragment,
+            'csv': FileFragment,
+            'json': FileFragment,
+            'orc': FileFragment,
+            'parquet': _get_parquet_symbol('ParquetFileFragment'),
+        }
+
+        class_ = classes.get(type_name, None)
+        if class_ is None:
+            class_ = Fragment
+
+        cdef Fragment self = class_.__new__(class_)
+        self.init(sp)
+        return self
+
+    cdef inline shared_ptr[CFragment] unwrap(self):
+        return self.wrapped
+
+    @property
+    def physical_schema(self):
+        """Return the physical schema of this Fragment. This schema can be
+        different from the dataset read schema."""
+        cdef:
+            CResult[shared_ptr[CSchema]] maybe_schema
+        with nogil:
+            maybe_schema = self.fragment.ReadPhysicalSchema()
+        return pyarrow_wrap_schema(GetResultValue(maybe_schema))
+
+    @property
+    def partition_expression(self):
+        """An Expression which evaluates to true for all data viewed by this
+        Fragment.
+ """ + return Expression.wrap(self.fragment.partition_expression()) + + def scanner(self, + Schema schema=None, + object columns=None, + Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, + MemoryPool memory_pool=None): + """ + Build a scan operation against the fragment. + + Data is not loaded immediately. Instead, this produces a Scanner, + which exposes further operations (e.g. loading all data as a + table, counting rows). + + Parameters + ---------- + schema : Schema + Schema to use for scanning. This is used to unify a Fragment to + its Dataset's schema. If not specified this will use the + Fragment's physical schema which might differ for each Fragment. + columns : list of str, default None + The columns to project. This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. + + The list of columns or expressions may use the special fields + `__batch_index` (the index of the batch within the fragment), + `__fragment_index` (the index of the fragment within the dataset), + `__last_in_fragment` (whether the batch is last in fragment), and + `__filename` (the name of the source file or a description of the + source fragment). + + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. 
+ + Returns + ------- + scanner : Scanner + """ + return Scanner.from_fragment( + self, + schema=schema, + columns=columns, + filter=filter, + batch_size=batch_size, + batch_readahead=batch_readahead, + fragment_readahead=fragment_readahead, + fragment_scan_options=fragment_scan_options, + use_threads=use_threads, + memory_pool=memory_pool + ) + + def to_batches(self, + Schema schema=None, + object columns=None, + Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, + MemoryPool memory_pool=None): + """ + Read the fragment as materialized record batches. + + Parameters + ---------- + schema : Schema, optional + Concrete schema to use for scanning. + columns : list of str, default None + The columns to project. This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. + + The list of columns or expressions may use the special fields + `__batch_index` (the index of the batch within the fragment), + `__fragment_index` (the index of the fragment within the dataset), + `__last_in_fragment` (whether the batch is last in fragment), and + `__filename` (the name of the source file or a description of the + source fragment). + + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. 
+ + Returns + ------- + record_batches : iterator of RecordBatch + """ + return Scanner.from_fragment( + self, + schema=schema, + columns=columns, + filter=filter, + batch_size=batch_size, + batch_readahead=batch_readahead, + fragment_readahead=fragment_readahead, + fragment_scan_options=fragment_scan_options, + use_threads=use_threads, + memory_pool=memory_pool + ).to_batches() + + def to_table(self, + Schema schema=None, + object columns=None, + Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, + MemoryPool memory_pool=None): + """ + Convert this Fragment into a Table. + + Use this convenience utility with care. This will serially materialize + the Scan result in memory before creating the Table. + + Parameters + ---------- + schema : Schema, optional + Concrete schema to use for scanning. + columns : list of str, default None + The columns to project. This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. + + The list of columns or expressions may use the special fields + `__batch_index` (the index of the batch within the fragment), + `__fragment_index` (the index of the fragment within the dataset), + `__last_in_fragment` (whether the batch is last in fragment), and + `__filename` (the name of the source file or a description of the + source fragment). + + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. 
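+
+ For example, a minimal sketch (assuming ``fragment`` is an existing
+ ``Fragment`` and ``pyarrow.compute`` is imported as ``pc``)::
+
+     table = fragment.to_table(
+         columns=["year", "month"],
+         filter=pc.field("month") == 11)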
+ + Returns + ------- + table : Table + """ + return self.scanner( + schema=schema, + columns=columns, + filter=filter, + batch_size=batch_size, + batch_readahead=batch_readahead, + fragment_readahead=fragment_readahead, + fragment_scan_options=fragment_scan_options, + use_threads=use_threads, + memory_pool=memory_pool + ).to_table() + + def take(self, + object indices, + object columns=None, + Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, + MemoryPool memory_pool=None): + """ + Select rows of data by index. + + Parameters + ---------- + indices : Array or array-like + The indices of row to select in the dataset. + columns : list of str, default None + The columns to project. This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. + + The list of columns or expressions may use the special fields + `__batch_index` (the index of the batch within the fragment), + `__fragment_index` (the index of the fragment within the dataset), + `__last_in_fragment` (whether the batch is last in fragment), and + `__filename` (the name of the source file or a description of the + source fragment). + + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. 
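+
+ For example, a minimal sketch (assuming ``fragment`` is an existing
+ ``Fragment``)::
+
+     import pyarrow as pa
+     rows = fragment.take(pa.array([0, 2, 5]))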
+ + Returns + ------- + Table + """ + return self.scanner( + columns=columns, + filter=filter, + batch_size=batch_size, + batch_readahead=batch_readahead, + fragment_readahead=fragment_readahead, + fragment_scan_options=fragment_scan_options, + use_threads=use_threads, + memory_pool=memory_pool + ).take(indices) + + def head(self, + int num_rows, + object columns=None, + Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, + MemoryPool memory_pool=None): + """ + Load the first N rows of the fragment. + + Parameters + ---------- + num_rows : int + The number of rows to load. + columns : list of str, default None + The columns to project. This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. + + The list of columns or expressions may use the special fields + `__batch_index` (the index of the batch within the fragment), + `__fragment_index` (the index of the fragment within the dataset), + `__last_in_fragment` (whether the batch is last in fragment), and + `__filename` (the name of the source file or a description of the + source fragment). + + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. 
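+
+ For example, a minimal sketch (assuming ``fragment`` is an existing
+ ``Fragment``)::
+
+     preview = fragment.head(5, columns=["year"])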
+ + Returns + ------- + Table + """ + return self.scanner( + columns=columns, + filter=filter, + batch_size=batch_size, + batch_readahead=batch_readahead, + fragment_readahead=fragment_readahead, + fragment_scan_options=fragment_scan_options, + use_threads=use_threads, + memory_pool=memory_pool + ).head(num_rows) + + def count_rows(self, + Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, + MemoryPool memory_pool=None): + """ + Count rows matching the scanner filter. + + Parameters + ---------- + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + + Returns + ------- + count : int + """ + return self.scanner( + filter=filter, + batch_size=batch_size, + batch_readahead=batch_readahead, + fragment_readahead=fragment_readahead, + fragment_scan_options=fragment_scan_options, + use_threads=use_threads, + memory_pool=memory_pool + ).count_rows() + + +cdef class FileFragment(Fragment): + """A Fragment representing a data file.""" + + cdef void init(self, const shared_ptr[CFragment]& sp): + Fragment.init(self, sp) + self.file_fragment = sp.get() + + def __repr__(self): + type_name = frombytes(self.fragment.type_name()) + if type_name != "parquet": + typ = f" type={type_name}" + else: + # parquet has a subclass -> type embedded in class name + typ = "" + partition_dict = get_partition_keys(self.partition_expression) + partition = ", ".join( + [f"{key}={val}" for key, val in partition_dict.items()] + ) + if partition: + partition = f" partition=[{partition}]" + return "".format( + self.__class__.__name__, typ, self.path, partition + ) + + def __reduce__(self): + buffer = self.buffer + return self.format.make_fragment, ( + self.path if buffer is None else buffer, + self.filesystem, + self.partition_expression + ) + + def open(self): + """ + Open a NativeFile of the buffer or file viewed by this fragment. 
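+
+ For example, a minimal sketch (assuming ``fragment`` is an existing
+ ``FileFragment``)::
+
+     f = fragment.open()
+     header = f.read(4)
+     f.close()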
+ """ + cdef: + shared_ptr[CFileSystem] c_filesystem + shared_ptr[CRandomAccessFile] opened + c_string c_path + NativeFile out = NativeFile() + + if self.buffer is not None: + return pa.BufferReader(self.buffer) + + c_path = tobytes(self.file_fragment.source().path()) + with nogil: + c_filesystem = self.file_fragment.source().filesystem() + opened = GetResultValue(c_filesystem.get().OpenInputFile(c_path)) + + out.set_random_access_file(opened) + out.is_readable = True + return out + + @property + def path(self): + """ + The path of the data file viewed by this fragment, if it views a + file. If instead it views a buffer, this will be "". + """ + return frombytes(self.file_fragment.source().path()) + + @property + def filesystem(self): + """ + The FileSystem containing the data file viewed by this fragment, if + it views a file. If instead it views a buffer, this will be None. + """ + cdef: + shared_ptr[CFileSystem] c_fs + c_fs = self.file_fragment.source().filesystem() + + if c_fs.get() == nullptr: + return None + + return FileSystem.wrap(c_fs) + + @property + def buffer(self): + """ + The buffer viewed by this fragment, if it views a buffer. If + instead it views a file, this will be None. + """ + cdef: + shared_ptr[CBuffer] c_buffer + c_buffer = self.file_fragment.source().buffer() + + if c_buffer.get() == nullptr: + return None + + return pyarrow_wrap_buffer(c_buffer) + + @property + def format(self): + """ + The format of the data file viewed by this fragment. + """ + return FileFormat.wrap(self.file_fragment.format()) + + +cdef class FragmentScanOptions(_Weakrefable): + """Scan options specific to a particular fragment and scan operation.""" + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp): + self.wrapped = sp + + @staticmethod + cdef wrap(const shared_ptr[CFragmentScanOptions]& sp): + if not sp: + return None + + type_name = frombytes(sp.get().type_name()) + + classes = { + 'csv': CsvFragmentScanOptions, + 'json': JsonFragmentScanOptions, + 'parquet': _get_parquet_symbol('ParquetFragmentScanOptions'), + } + + class_ = classes.get(type_name, None) + if class_ is None: + raise TypeError(type_name) + + cdef FragmentScanOptions self = class_.__new__(class_) + self.init(sp) + return self + + @property + def type_name(self): + return frombytes(self.wrapped.get().type_name()) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return False + + +cdef class IpcFileWriteOptions(FileWriteOptions): + cdef: + CIpcFileWriteOptions* ipc_options + + def __init__(self): + _forbid_instantiation(self.__class__) + + @property + def write_options(self): + out = IpcWriteOptions() + out.c_options = CIpcWriteOptions(deref(self.ipc_options.options)) + return out + + @write_options.setter + def write_options(self, IpcWriteOptions write_options not None): + self.ipc_options.options.reset( + new CIpcWriteOptions(write_options.c_options)) + + cdef void init(self, const shared_ptr[CFileWriteOptions]& sp): + FileWriteOptions.init(self, sp) + self.ipc_options = sp.get() + + +cdef class IpcFileFormat(FileFormat): + + def __init__(self): + self.init(shared_ptr[CFileFormat](new CIpcFileFormat())) + + def equals(self, IpcFileFormat other): + """ + Parameters + ---------- + other : pyarrow.dataset.IpcFileFormat + + Returns + ------- + True + """ + return True + + def make_write_options(self, **kwargs): + """ + Parameters + ---------- + **kwargs : dict + + Returns + ------- + pyarrow.ipc.IpcWriteOptions + 
""" + cdef IpcFileWriteOptions opts = \ + FileFormat.make_write_options(self) + opts.write_options = IpcWriteOptions(**kwargs) + return opts + + @property + def default_extname(self): + return "arrow" + + def __reduce__(self): + return IpcFileFormat, tuple() + + +cdef class FeatherFileFormat(IpcFileFormat): + + @property + def default_extname(self): + return "feather" + + +cdef class CsvFileFormat(FileFormat): + """ + FileFormat for CSV files. + + Parameters + ---------- + parse_options : pyarrow.csv.ParseOptions + Options regarding CSV parsing. + default_fragment_scan_options : CsvFragmentScanOptions + Default options for fragments scan. + convert_options : pyarrow.csv.ConvertOptions + Options regarding value conversion. + read_options : pyarrow.csv.ReadOptions + General read options. + """ + cdef: + CCsvFileFormat* csv_format + # The encoding field in ReadOptions does not exist in the C++ struct. + # We need to store it here and override it when reading + # default_fragment_scan_options.read_options + public ReadOptions _read_options_py + + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, ParseOptions parse_options=None, + default_fragment_scan_options=None, + ConvertOptions convert_options=None, + ReadOptions read_options=None): + self.init(shared_ptr[CFileFormat](new CCsvFileFormat())) + if parse_options is not None: + self.parse_options = parse_options + if convert_options is not None or read_options is not None: + if default_fragment_scan_options: + raise ValueError('If `default_fragment_scan_options` is ' + 'given, cannot specify convert_options ' + 'or read_options') + self.default_fragment_scan_options = CsvFragmentScanOptions( + convert_options=convert_options, read_options=read_options) + elif isinstance(default_fragment_scan_options, dict): + self.default_fragment_scan_options = CsvFragmentScanOptions( + **default_fragment_scan_options) + elif isinstance(default_fragment_scan_options, CsvFragmentScanOptions): + self.default_fragment_scan_options = default_fragment_scan_options + elif default_fragment_scan_options is not None: + raise TypeError('`default_fragment_scan_options` must be either ' + 'a dictionary or an instance of ' + 'CsvFragmentScanOptions') + if read_options is not None: + self._read_options_py = read_options + + cdef void init(self, const shared_ptr[CFileFormat]& sp): + FileFormat.init(self, sp) + self.csv_format = sp.get() + + def make_write_options(self, **kwargs): + """ + Parameters + ---------- + **kwargs : dict + + Returns + ------- + pyarrow.csv.WriteOptions + """ + cdef CsvFileWriteOptions opts = \ + FileFormat.make_write_options(self) + opts.write_options = WriteOptions(**kwargs) + return opts + + @property + def parse_options(self): + return ParseOptions.wrap(self.csv_format.parse_options) + + @parse_options.setter + def parse_options(self, ParseOptions parse_options not None): + self.csv_format.parse_options = deref(parse_options.options) + + cdef _set_default_fragment_scan_options(self, FragmentScanOptions options): + if options.type_name == 'csv': + self.csv_format.default_fragment_scan_options = options.wrapped + self.default_fragment_scan_options.read_options = options.read_options + self._read_options_py = options.read_options + else: + super()._set_default_fragment_scan_options(options) + + def equals(self, CsvFileFormat other): + """ + Parameters + ---------- + other : pyarrow.dataset.CsvFileFormat + + Returns + ------- + bool + """ + return ( + self.parse_options.equals(other.parse_options) and + 
self.default_fragment_scan_options == + other.default_fragment_scan_options) + + def __reduce__(self): + return CsvFileFormat, (self.parse_options, + self.default_fragment_scan_options) + + def __repr__(self): + return f"" + + +cdef class CsvFragmentScanOptions(FragmentScanOptions): + """ + Scan-specific options for CSV fragments. + + Parameters + ---------- + convert_options : pyarrow.csv.ConvertOptions + Options regarding value conversion. + read_options : pyarrow.csv.ReadOptions + General read options. + """ + + cdef: + CCsvFragmentScanOptions* csv_options + # The encoding field in ReadOptions does not exist in the C++ struct. + # We need to store it here and override it when reading read_options + ReadOptions _read_options_py + + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, ConvertOptions convert_options=None, + ReadOptions read_options=None): + self.init(shared_ptr[CFragmentScanOptions]( + new CCsvFragmentScanOptions())) + if convert_options is not None: + self.convert_options = convert_options + if read_options is not None: + self.read_options = read_options + self._read_options_py = read_options + + cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp): + FragmentScanOptions.init(self, sp) + self.csv_options = sp.get() + + @property + def convert_options(self): + return ConvertOptions.wrap(self.csv_options.convert_options) + + @convert_options.setter + def convert_options(self, ConvertOptions convert_options not None): + self.csv_options.convert_options = deref(convert_options.options) + + @property + def read_options(self): + read_options = ReadOptions.wrap(self.csv_options.read_options) + if self._read_options_py is not None: + read_options.encoding = self._read_options_py.encoding + return read_options + + @read_options.setter + def read_options(self, ReadOptions read_options not None): + self.csv_options.read_options = deref(read_options.options) + self._read_options_py = read_options + if codecs.lookup(read_options.encoding).name != 'utf-8': + self.csv_options.stream_transform_func = deref( + make_streamwrap_func(read_options.encoding, 'utf-8')) + + def equals(self, CsvFragmentScanOptions other): + """ + Parameters + ---------- + other : pyarrow.dataset.CsvFragmentScanOptions + + Returns + ------- + bool + """ + return ( + other and + self.convert_options.equals(other.convert_options) and + self.read_options.equals(other.read_options)) + + def __reduce__(self): + return CsvFragmentScanOptions, (self.convert_options, + self.read_options) + + +cdef class CsvFileWriteOptions(FileWriteOptions): + cdef: + CCsvFileWriteOptions* csv_options + object _properties + + def __init__(self): + _forbid_instantiation(self.__class__) + + @property + def write_options(self): + return WriteOptions.wrap(deref(self.csv_options.write_options)) + + @write_options.setter + def write_options(self, WriteOptions write_options not None): + self.csv_options.write_options.reset( + new CCSVWriteOptions(deref(write_options.options))) + + cdef void init(self, const shared_ptr[CFileWriteOptions]& sp): + FileWriteOptions.init(self, sp) + self.csv_options = sp.get() + + +cdef class JsonFileFormat(FileFormat): + """ + FileFormat for JSON files. + + Parameters + ---------- + default_fragment_scan_options : JsonFragmentScanOptions + Default options for fragments scan. + parse_options : pyarrow.json.ParseOptions + Options regarding json parsing. + read_options : pyarrow.json.ReadOptions + General read options. 
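+
+ Examples
+ --------
+ A minimal sketch (assuming ``pyarrow.dataset`` is imported as ``ds``,
+ ``pyarrow.json`` as ``pa_json``, and the data path is hypothetical)::
+
+     fmt = ds.JsonFileFormat(
+         read_options=pa_json.ReadOptions(block_size=1 << 20))
+     dataset = ds.dataset("data/", format=fmt)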
+ """ + cdef: + CJsonFileFormat* json_format + + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, default_fragment_scan_options=None, + JsonParseOptions parse_options=None, + JsonReadOptions read_options=None): + self.init(shared_ptr[CFileFormat](new CJsonFileFormat())) + if parse_options is not None or read_options is not None: + if default_fragment_scan_options is not None: + raise ValueError('If `default_fragment_scan_options` is ' + 'given, cannot specify read_options') + self.default_fragment_scan_options = JsonFragmentScanOptions( + parse_options=parse_options, + read_options=read_options) + elif isinstance(default_fragment_scan_options, dict): + self.default_fragment_scan_options = JsonFragmentScanOptions( + **default_fragment_scan_options) + elif isinstance(default_fragment_scan_options, JsonFragmentScanOptions): + self.default_fragment_scan_options = default_fragment_scan_options + elif default_fragment_scan_options is not None: + raise TypeError('`default_fragment_scan_options` must be either ' + 'a dictionary or an instance of ' + 'JsonFragmentScanOptions') + + cdef void init(self, const shared_ptr[CFileFormat]& sp): + FileFormat.init(self, sp) + self.json_format = sp.get() + + cdef _set_default_fragment_scan_options(self, FragmentScanOptions options): + if options.type_name == 'json': + self.json_format.default_fragment_scan_options = options.wrapped + self.default_fragment_scan_options.read_options = options.read_options + self.default_fragment_scan_options.parse_options = options.parse_options + else: + super()._set_default_fragment_scan_options(options) + + def equals(self, JsonFileFormat other): + """ + Parameters + ---------- + other : pyarrow.dataset.JsonFileFormat + + Returns + ------- + bool + """ + return (other and + self.default_fragment_scan_options == + other.default_fragment_scan_options) + + def __reduce__(self): + return JsonFileFormat, (self.default_fragment_scan_options,) + + def __repr__(self): + return "" + + +cdef class JsonFragmentScanOptions(FragmentScanOptions): + """ + Scan-specific options for JSON fragments. + + Parameters + ---------- + parse_options : pyarrow.json.ParseOptions + Options regarding JSON parsing. + read_options : pyarrow.json.ReadOptions + General read options. 
+ """ + cdef: + CJsonFragmentScanOptions* json_options + + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, JsonParseOptions parse_options=None, + JsonReadOptions read_options=None): + self.init(shared_ptr[CFragmentScanOptions]( + new CJsonFragmentScanOptions())) + if parse_options is not None: + self.parse_options = parse_options + if read_options is not None: + self.read_options = read_options + + cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp): + FragmentScanOptions.init(self, sp) + self.json_options = sp.get() + + @property + def parse_options(self): + return JsonParseOptions.wrap(self.json_options.parse_options) + + @parse_options.setter + def parse_options(self, JsonParseOptions parse_options not None): + self.json_options.parse_options = parse_options.options + + @property + def read_options(self): + return JsonReadOptions.wrap(self.json_options.read_options) + + @read_options.setter + def read_options(self, JsonReadOptions read_options not None): + self.json_options.read_options = read_options.options + + def equals(self, JsonFragmentScanOptions other): + """ + Parameters + ---------- + other : pyarrow.dataset.JsonFragmentScanOptions + + Returns + ------- + bool + """ + return ( + other and + self.read_options.equals(other.read_options) and + self.parse_options.equals(other.parse_options)) + + def __reduce__(self): + return JsonFragmentScanOptions, (self.parse_options, self.read_options) + + +cdef class Partitioning(_Weakrefable): + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef init(self, const shared_ptr[CPartitioning]& sp): + self.wrapped = sp + self.partitioning = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CPartitioning]& sp): + type_name = frombytes(sp.get().type_name()) + + classes = { + 'directory': DirectoryPartitioning, + 'hive': HivePartitioning, + 'filename': FilenamePartitioning, + } + + class_ = classes.get(type_name, None) + if class_ is None: + raise TypeError(type_name) + + cdef Partitioning self = class_.__new__(class_) + self.init(sp) + return self + + cdef inline shared_ptr[CPartitioning] unwrap(self): + return self.wrapped + + def __eq__(self, other): + if isinstance(other, Partitioning): + return self.partitioning.Equals(deref((other).unwrap())) + return False + + def parse(self, path): + """ + Parse a path into a partition expression. + + Parameters + ---------- + path : str + + Returns + ------- + pyarrow.dataset.Expression + """ + cdef CResult[CExpression] result + result = self.partitioning.Parse(tobytes(path)) + return Expression.wrap(GetResultValue(result)) + + def format(self, expr): + """ + Convert a filter expression into a tuple of (directory, filename) using + the current partitioning scheme + + Parameters + ---------- + expr : pyarrow.dataset.Expression + + Returns + ------- + tuple[str, str] + + Examples + -------- + + Specify the Schema for paths like "/2009/June": + + >>> import pyarrow as pa + >>> import pyarrow.dataset as ds + >>> import pyarrow.compute as pc + >>> part = ds.partitioning(pa.schema([("year", pa.int16()), + ... ("month", pa.string())])) + >>> part.format( + ... (pc.field("year") == 1862) & (pc.field("month") == "Jan") + ... 
) + ('1862/Jan', '') + """ + cdef: + CPartitionPathFormat result + + result = GetResultValue(self.partitioning.Format( + Expression.unwrap(expr) + )) + + return frombytes(result.directory), frombytes(result.filename) + + @property + def schema(self): + """The arrow Schema attached to the partitioning.""" + return pyarrow_wrap_schema(self.partitioning.schema()) + + +cdef class PartitioningFactory(_Weakrefable): + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef init(self, const shared_ptr[CPartitioningFactory]& sp): + self.wrapped = sp + self.factory = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CPartitioningFactory]& sp, + object constructor, object options): + cdef PartitioningFactory self = PartitioningFactory.__new__( + PartitioningFactory + ) + self.init(sp) + self.constructor = constructor + self.options = options + return self + + cdef inline shared_ptr[CPartitioningFactory] unwrap(self): + return self.wrapped + + def __reduce__(self): + return self.constructor, self.options + + @property + def type_name(self): + return frombytes(self.factory.type_name()) + + +cdef vector[shared_ptr[CArray]] _partitioning_dictionaries( + Schema schema, dictionaries) except *: + cdef: + vector[shared_ptr[CArray]] c_dictionaries + + dictionaries = dictionaries or {} + + for field in schema: + dictionary = dictionaries.get(field.name) + + if (isinstance(field.type, pa.DictionaryType) and + dictionary is not None): + c_dictionaries.push_back(pyarrow_unwrap_array(dictionary)) + else: + c_dictionaries.push_back( nullptr) + + return c_dictionaries + + +cdef class KeyValuePartitioning(Partitioning): + + cdef: + CKeyValuePartitioning* keyvalue_partitioning + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef init(self, const shared_ptr[CPartitioning]& sp): + Partitioning.init(self, sp) + self.keyvalue_partitioning = sp.get() + self.wrapped = sp + self.partitioning = sp.get() + + def __reduce__(self): + dictionaries = self.dictionaries + if dictionaries: + dictionaries = dict(zip(self.schema.names, dictionaries)) + segment_encoding = _wrap_segment_encoding( + deref(self.keyvalue_partitioning).segment_encoding() + ) + return self.__class__, (self.schema, dictionaries, segment_encoding) + + @property + def dictionaries(self): + """ + The unique values for each partition field, if available. + + Those values are only available if the Partitioning object was + created through dataset discovery from a PartitioningFactory, or + if the dictionaries were manually specified in the constructor. + If no dictionary field is available, this returns an empty list. + """ + cdef vector[shared_ptr[CArray]] c_arrays + c_arrays = self.keyvalue_partitioning.dictionaries() + res = [] + for arr in c_arrays: + if arr.get() == nullptr: + # Partitioning object has not been created through + # inspected Factory + res.append(None) + else: + res.append(pyarrow_wrap_array(arr)) + return res + + +def _constructor_directory_partitioning_factory(*args): + return DirectoryPartitioning.discover(*args) + + +cdef class DirectoryPartitioning(KeyValuePartitioning): + """ + A Partitioning based on a specified Schema. + + The DirectoryPartitioning expects one segment in the file path for each + field in the schema (all fields are required to be present). + For example given schema the path "/2009/11" would + be parsed to ("year"_ == 2009 and "month"_ == 11). + + Parameters + ---------- + schema : Schema + The schema that describes the partitions present in the file path. 
+ dictionaries : dict[str, Array] + If the type of any field of `schema` is a dictionary type, the + corresponding entry of `dictionaries` must be an array containing + every value which may be taken by the corresponding column or an + error will be raised in parsing. + segment_encoding : str, default "uri" + After splitting paths into segments, decode the segments. Valid + values are "uri" (URI-decode segments) and "none" (leave as-is). + + Returns + ------- + DirectoryPartitioning + + Examples + -------- + >>> from pyarrow.dataset import DirectoryPartitioning + >>> partitioning = DirectoryPartitioning( + ... pa.schema([("year", pa.int16()), ("month", pa.int8())])) + >>> print(partitioning.parse("/2009/11/")) + ((year == 2009) and (month == 11)) + """ + + cdef: + CDirectoryPartitioning* directory_partitioning + + def __init__(self, Schema schema not None, dictionaries=None, + segment_encoding="uri"): + cdef: + shared_ptr[CDirectoryPartitioning] c_partitioning + CKeyValuePartitioningOptions c_options + + c_options.segment_encoding = _get_segment_encoding(segment_encoding) + c_partitioning = make_shared[CDirectoryPartitioning]( + pyarrow_unwrap_schema(schema), + _partitioning_dictionaries(schema, dictionaries), + c_options, + ) + self.init( c_partitioning) + + cdef init(self, const shared_ptr[CPartitioning]& sp): + KeyValuePartitioning.init(self, sp) + self.directory_partitioning = sp.get() + + @staticmethod + def discover(field_names=None, infer_dictionary=False, + max_partition_dictionary_size=0, + schema=None, segment_encoding="uri"): + """ + Discover a DirectoryPartitioning. + + Parameters + ---------- + field_names : list of str + The names to associate with the values from the subdirectory names. + If schema is given, will be populated from the schema. + infer_dictionary : bool, default False + When inferring a schema for partition fields, yield dictionary + encoded types instead of plain types. This can be more efficient + when materializing virtual columns, and Expressions parsed by the + finished Partitioning will include dictionaries of all unique + inspected values for each field. + max_partition_dictionary_size : int, default 0 + Synonymous with infer_dictionary for backwards compatibility with + 1.0: setting this to -1 or None is equivalent to passing + infer_dictionary=True. + schema : Schema, default None + Use this schema instead of inferring a schema from partition + values. Partition values will be validated against this schema + before accumulation into the Partitioning's dictionary. + segment_encoding : str, default "uri" + After splitting paths into segments, decode the segments. Valid + values are "uri" (URI-decode segments) and "none" (leave as-is). + + Returns + ------- + PartitioningFactory + To be used in the FileSystemFactoryOptions. 
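+
+ Examples
+ --------
+ A minimal sketch (assuming ``pyarrow.dataset`` is imported as ``ds``
+ and the directory path is hypothetical)::
+
+     factory = ds.DirectoryPartitioning.discover(
+         field_names=["year", "month"], infer_dictionary=True)
+     dataset = ds.dataset("data/", format="parquet",
+                          partitioning=factory)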
+ """ + cdef: + CPartitioningFactoryOptions c_options + vector[c_string] c_field_names + + if max_partition_dictionary_size in {-1, None}: + infer_dictionary = True + elif max_partition_dictionary_size != 0: + raise NotImplementedError("max_partition_dictionary_size must be " + "0, -1, or None") + + if infer_dictionary: + c_options.infer_dictionary = True + + if schema: + c_options.schema = pyarrow_unwrap_schema(schema) + c_field_names = [tobytes(f.name) for f in schema] + elif not field_names: + raise ValueError( + "Neither field_names nor schema was passed; " + "cannot infer field_names") + else: + c_field_names = [tobytes(s) for s in field_names] + + c_options.segment_encoding = _get_segment_encoding(segment_encoding) + + return PartitioningFactory.wrap( + CDirectoryPartitioning.MakeFactory(c_field_names, c_options), + _constructor_directory_partitioning_factory, + (field_names, infer_dictionary, max_partition_dictionary_size, + schema, segment_encoding) + ) + + +def _constructor_hive_partitioning_factory(*args): + return HivePartitioning.discover(*args) + + +cdef class HivePartitioning(KeyValuePartitioning): + """ + A Partitioning for "/$key=$value/" nested directories as found in + Apache Hive. + + Multi-level, directory based partitioning scheme originating from + Apache Hive with all data files stored in the leaf directories. Data is + partitioned by static values of a particular column in the schema. + Partition keys are represented in the form $key=$value in directory names. + Field order is ignored, as are missing or unrecognized field names. + + For example, given schema, a possible + path would be "/year=2009/month=11/day=15". + + Parameters + ---------- + schema : Schema + The schema that describes the partitions present in the file path. + dictionaries : dict[str, Array] + If the type of any field of `schema` is a dictionary type, the + corresponding entry of `dictionaries` must be an array containing + every value which may be taken by the corresponding column or an + error will be raised in parsing. + null_fallback : str, default "__HIVE_DEFAULT_PARTITION__" + If any field is None then this fallback will be used as a label + segment_encoding : str, default "uri" + After splitting paths into segments, decode the segments. Valid + values are "uri" (URI-decode segments) and "none" (leave as-is). + + Returns + ------- + HivePartitioning + + Examples + -------- + >>> from pyarrow.dataset import HivePartitioning + >>> partitioning = HivePartitioning( + ... 
pa.schema([("year", pa.int16()), ("month", pa.int8())])) + >>> print(partitioning.parse("/year=2009/month=11/")) + ((year == 2009) and (month == 11)) + + """ + + cdef: + CHivePartitioning* hive_partitioning + + def __init__(self, + Schema schema not None, + dictionaries=None, + null_fallback="__HIVE_DEFAULT_PARTITION__", + segment_encoding="uri"): + + cdef: + shared_ptr[CHivePartitioning] c_partitioning + CHivePartitioningOptions c_options + + c_options.null_fallback = tobytes(null_fallback) + c_options.segment_encoding = _get_segment_encoding(segment_encoding) + + c_partitioning = make_shared[CHivePartitioning]( + pyarrow_unwrap_schema(schema), + _partitioning_dictionaries(schema, dictionaries), + c_options, + ) + self.init( c_partitioning) + + cdef init(self, const shared_ptr[CPartitioning]& sp): + KeyValuePartitioning.init(self, sp) + self.hive_partitioning = sp.get() + + def __reduce__(self): + dictionaries = self.dictionaries + if dictionaries: + dictionaries = dict(zip(self.schema.names, dictionaries)) + segment_encoding = _wrap_segment_encoding( + deref(self.keyvalue_partitioning).segment_encoding() + ) + null_fallback = frombytes(deref(self.hive_partitioning).null_fallback()) + return HivePartitioning, ( + self.schema, dictionaries, null_fallback, segment_encoding + ) + + @staticmethod + def discover(infer_dictionary=False, + max_partition_dictionary_size=0, + null_fallback="__HIVE_DEFAULT_PARTITION__", + schema=None, + segment_encoding="uri"): + """ + Discover a HivePartitioning. + + Parameters + ---------- + infer_dictionary : bool, default False + When inferring a schema for partition fields, yield dictionary + encoded types instead of plain. This can be more efficient when + materializing virtual columns, and Expressions parsed by the + finished Partitioning will include dictionaries of all unique + inspected values for each field. + max_partition_dictionary_size : int, default 0 + Synonymous with infer_dictionary for backwards compatibility with + 1.0: setting this to -1 or None is equivalent to passing + infer_dictionary=True. + null_fallback : str, default "__HIVE_DEFAULT_PARTITION__" + When inferring a schema for partition fields this value will be + replaced by null. The default is set to __HIVE_DEFAULT_PARTITION__ + for compatibility with Spark + schema : Schema, default None + Use this schema instead of inferring a schema from partition + values. Partition values will be validated against this schema + before accumulation into the Partitioning's dictionary. + segment_encoding : str, default "uri" + After splitting paths into segments, decode the segments. Valid + values are "uri" (URI-decode segments) and "none" (leave as-is). + + Returns + ------- + PartitioningFactory + To be used in the FileSystemFactoryOptions. 
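+
+ Examples
+ --------
+ A minimal sketch (assuming ``pyarrow.dataset`` is imported as ``ds``
+ and the directory path is hypothetical)::
+
+     factory = ds.HivePartitioning.discover(infer_dictionary=True)
+     dataset = ds.dataset("data/", format="parquet",
+                          partitioning=factory)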
+ """ + cdef: + CHivePartitioningFactoryOptions c_options + + if max_partition_dictionary_size in {-1, None}: + infer_dictionary = True + elif max_partition_dictionary_size != 0: + raise NotImplementedError("max_partition_dictionary_size must be " + "0, -1, or None") + + if infer_dictionary: + c_options.infer_dictionary = True + + c_options.null_fallback = tobytes(null_fallback) + + if schema: + c_options.schema = pyarrow_unwrap_schema(schema) + + c_options.segment_encoding = _get_segment_encoding(segment_encoding) + + return PartitioningFactory.wrap( + CHivePartitioning.MakeFactory(c_options), + _constructor_hive_partitioning_factory, + (infer_dictionary, max_partition_dictionary_size, null_fallback, + schema, segment_encoding), + ) + + +def _constructor_filename_partitioning_factory(*args): + return FilenamePartitioning.discover(*args) + + +cdef class FilenamePartitioning(KeyValuePartitioning): + """ + A Partitioning based on a specified Schema. + + The FilenamePartitioning expects one segment in the file name for each + field in the schema (all fields are required to be present) separated + by '_'. For example given schema the name + ``"2009_11_"`` would be parsed to ("year" == 2009 and "month" == 11). + + Parameters + ---------- + schema : Schema + The schema that describes the partitions present in the file path. + dictionaries : dict[str, Array] + If the type of any field of `schema` is a dictionary type, the + corresponding entry of `dictionaries` must be an array containing + every value which may be taken by the corresponding column or an + error will be raised in parsing. + segment_encoding : str, default "uri" + After splitting paths into segments, decode the segments. Valid + values are "uri" (URI-decode segments) and "none" (leave as-is). + + Returns + ------- + FilenamePartitioning + + Examples + -------- + >>> from pyarrow.dataset import FilenamePartitioning + >>> partitioning = FilenamePartitioning( + ... pa.schema([("year", pa.int16()), ("month", pa.int8())])) + >>> print(partitioning.parse("2009_11_data.parquet")) + ((year == 2009) and (month == 11)) + """ + + cdef: + CFilenamePartitioning* filename_partitioning + + def __init__(self, Schema schema not None, dictionaries=None, + segment_encoding="uri"): + cdef: + shared_ptr[CFilenamePartitioning] c_partitioning + CKeyValuePartitioningOptions c_options + + c_options.segment_encoding = _get_segment_encoding(segment_encoding) + c_partitioning = make_shared[CFilenamePartitioning]( + pyarrow_unwrap_schema(schema), + _partitioning_dictionaries(schema, dictionaries), + c_options, + ) + self.init( c_partitioning) + + cdef init(self, const shared_ptr[CPartitioning]& sp): + KeyValuePartitioning.init(self, sp) + self.filename_partitioning = sp.get() + + @staticmethod + def discover(field_names=None, infer_dictionary=False, + schema=None, segment_encoding="uri"): + """ + Discover a FilenamePartitioning. + + Parameters + ---------- + field_names : list of str + The names to associate with the values from the subdirectory names. + If schema is given, will be populated from the schema. + infer_dictionary : bool, default False + When inferring a schema for partition fields, yield dictionary + encoded types instead of plain types. This can be more efficient + when materializing virtual columns, and Expressions parsed by the + finished Partitioning will include dictionaries of all unique + inspected values for each field. + schema : Schema, default None + Use this schema instead of inferring a schema from partition + values. 
Partition values will be validated against this schema + before accumulation into the Partitioning's dictionary. + segment_encoding : str, default "uri" + After splitting paths into segments, decode the segments. Valid + values are "uri" (URI-decode segments) and "none" (leave as-is). + + Returns + ------- + PartitioningFactory + To be used in the FileSystemFactoryOptions. + """ + cdef: + CPartitioningFactoryOptions c_options + vector[c_string] c_field_names + + if infer_dictionary: + c_options.infer_dictionary = True + + if schema: + c_options.schema = pyarrow_unwrap_schema(schema) + c_field_names = [tobytes(f.name) for f in schema] + elif not field_names: + raise TypeError( + "Neither field_names nor schema was passed; " + "cannot infer field_names") + else: + c_field_names = [tobytes(s) for s in field_names] + + c_options.segment_encoding = _get_segment_encoding(segment_encoding) + + return PartitioningFactory.wrap( + CFilenamePartitioning.MakeFactory(c_field_names, c_options), + _constructor_filename_partitioning_factory, + (field_names, infer_dictionary, schema, segment_encoding) + ) + + +cdef class DatasetFactory(_Weakrefable): + """ + DatasetFactory is used to create a Dataset, inspect the Schema + of the fragments contained in it, and declare a partitioning. + """ + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef init(self, const shared_ptr[CDatasetFactory]& sp): + self.wrapped = sp + self.factory = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CDatasetFactory]& sp): + cdef DatasetFactory self = \ + DatasetFactory.__new__(DatasetFactory) + self.init(sp) + return self + + cdef inline shared_ptr[CDatasetFactory] unwrap(self) nogil: + return self.wrapped + + @property + def root_partition(self): + return Expression.wrap(self.factory.root_partition()) + + @root_partition.setter + def root_partition(self, Expression expr): + check_status(self.factory.SetRootPartition(expr.unwrap())) + + def inspect_schemas(self): + cdef CResult[vector[shared_ptr[CSchema]]] result + cdef CInspectOptions options + with nogil: + result = self.factory.InspectSchemas(options) + + schemas = [] + for s in GetResultValue(result): + schemas.append(pyarrow_wrap_schema(s)) + return schemas + + def inspect(self): + """ + Inspect all data fragments and return a common Schema. + + Returns + ------- + Schema + """ + cdef: + CInspectOptions options + CResult[shared_ptr[CSchema]] result + with nogil: + result = self.factory.Inspect(options) + return pyarrow_wrap_schema(GetResultValue(result)) + + def finish(self, Schema schema=None): + """ + Create a Dataset using the inspected schema or an explicit schema + (if given). + + Parameters + ---------- + schema : Schema, default None + The schema to conform the source to. If None, the inspected + schema is used. + + Returns + ------- + Dataset + """ + cdef: + shared_ptr[CSchema] sp_schema + CResult[shared_ptr[CDataset]] result + + if schema is not None: + sp_schema = pyarrow_unwrap_schema(schema) + with nogil: + result = self.factory.FinishWithSchema(sp_schema) + else: + with nogil: + result = self.factory.Finish() + + return Dataset.wrap(GetResultValue(result)) + + +cdef class FileSystemFactoryOptions(_Weakrefable): + """ + Influences the discovery of filesystem paths. + + Parameters + ---------- + partition_base_dir : str, optional + For the purposes of applying the partitioning, paths will be + stripped of the partition_base_dir. Files not matching the + partition_base_dir prefix will be skipped for partitioning discovery. 
+ The ignored files will still be part of the Dataset, but will not + have partition information. + partitioning : Partitioning/PartitioningFactory, optional + Apply the Partitioning to every discovered Fragment. See Partitioning or + PartitioningFactory documentation. + exclude_invalid_files : bool, optional (default True) + If True, invalid files will be excluded (file format specific check). + This will incur IO for each files in a serial and single threaded + fashion. Disabling this feature will skip the IO, but unsupported + files may be present in the Dataset (resulting in an error at scan + time). + selector_ignore_prefixes : list, optional + When discovering from a Selector (and not from an explicit file list), + ignore files and directories matching any of these prefixes. + By default this is ['.', '_']. + """ + + cdef: + CFileSystemFactoryOptions options + + __slots__ = () # avoid mistakingly creating attributes + + def __init__(self, partition_base_dir=None, partitioning=None, + exclude_invalid_files=None, + list selector_ignore_prefixes=None): + if isinstance(partitioning, PartitioningFactory): + self.partitioning_factory = partitioning + elif isinstance(partitioning, Partitioning): + self.partitioning = partitioning + + if partition_base_dir is not None: + self.partition_base_dir = partition_base_dir + if exclude_invalid_files is not None: + self.exclude_invalid_files = exclude_invalid_files + if selector_ignore_prefixes is not None: + self.selector_ignore_prefixes = selector_ignore_prefixes + + cdef inline CFileSystemFactoryOptions unwrap(self): + return self.options + + @property + def partitioning(self): + """Partitioning to apply to discovered files. + + NOTE: setting this property will overwrite partitioning_factory. + """ + c_partitioning = self.options.partitioning.partitioning() + if c_partitioning.get() == nullptr: + return None + return Partitioning.wrap(c_partitioning) + + @partitioning.setter + def partitioning(self, Partitioning value): + self.options.partitioning = ( value).unwrap() + + @property + def partitioning_factory(self): + """PartitioningFactory to apply to discovered files and + discover a Partitioning. + + NOTE: setting this property will overwrite partitioning. + """ + c_factory = self.options.partitioning.factory() + if c_factory.get() == nullptr: + return None + return PartitioningFactory.wrap(c_factory, None, None) + + @partitioning_factory.setter + def partitioning_factory(self, PartitioningFactory value): + self.options.partitioning = ( value).unwrap() + + @property + def partition_base_dir(self): + """ + Base directory to strip paths before applying the partitioning. + """ + return frombytes(self.options.partition_base_dir) + + @partition_base_dir.setter + def partition_base_dir(self, value): + self.options.partition_base_dir = tobytes(value) + + @property + def exclude_invalid_files(self): + """Whether to exclude invalid files.""" + return self.options.exclude_invalid_files + + @exclude_invalid_files.setter + def exclude_invalid_files(self, bint value): + self.options.exclude_invalid_files = value + + @property + def selector_ignore_prefixes(self): + """ + List of prefixes. Files matching one of those prefixes will be + ignored by the discovery process. 
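+
+ For example, a minimal sketch of configuring these options (assuming
+ ``pyarrow.dataset`` is imported as ``ds``; the base directory is
+ hypothetical)::
+
+     opts = ds.FileSystemFactoryOptions(
+         partition_base_dir="data/",
+         exclude_invalid_files=False,
+         selector_ignore_prefixes=[".", "_", "tmp"])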
+ """ + return [frombytes(p) for p in self.options.selector_ignore_prefixes] + + @selector_ignore_prefixes.setter + def selector_ignore_prefixes(self, values): + self.options.selector_ignore_prefixes = [tobytes(v) for v in values] + + +cdef vector[CFileInfo] unwrap_finfos(finfos): + cdef vector[CFileInfo] o_vect + for fi in finfos: + o_vect.push_back(( fi).unwrap()) + return o_vect + + +cdef class FileSystemDatasetFactory(DatasetFactory): + """ + Create a DatasetFactory from a list of paths with schema inspection. + + Parameters + ---------- + filesystem : pyarrow.fs.FileSystem + Filesystem to discover. + paths_or_selector : pyarrow.fs.FileSelector or list of path-likes + Either a Selector object or a list of path-like objects. + format : FileFormat + Currently only ParquetFileFormat and IpcFileFormat are supported. + options : FileSystemFactoryOptions, optional + Various flags influencing the discovery of filesystem paths. + """ + + cdef: + CFileSystemDatasetFactory* filesystem_factory + + def __init__(self, FileSystem filesystem not None, paths_or_selector, + FileFormat format not None, + FileSystemFactoryOptions options=None): + cdef: + vector[c_string] paths + vector[CFileInfo] finfos + CFileSelector c_selector + CResult[shared_ptr[CDatasetFactory]] result + shared_ptr[CFileSystem] c_filesystem + shared_ptr[CFileFormat] c_format + CFileSystemFactoryOptions c_options + + options = options or FileSystemFactoryOptions() + c_options = options.unwrap() + c_filesystem = filesystem.unwrap() + c_format = format.unwrap() + + if isinstance(paths_or_selector, FileSelector): + with nogil: + c_selector = ( paths_or_selector).selector + result = CFileSystemDatasetFactory.MakeFromSelector( + c_filesystem, + c_selector, + c_format, + c_options + ) + elif isinstance(paths_or_selector, (list, tuple)): + if len(paths_or_selector) > 0 and isinstance(paths_or_selector[0], FileInfo): + finfos = unwrap_finfos(paths_or_selector) + with nogil: + result = CFileSystemDatasetFactory.MakeFromFileInfos( + c_filesystem, + finfos, + c_format, + c_options + ) + else: + paths = [tobytes(s) for s in paths_or_selector] + with nogil: + result = CFileSystemDatasetFactory.MakeFromPaths( + c_filesystem, + paths, + c_format, + c_options + ) + else: + raise TypeError('Must pass either paths or a FileSelector, but ' + 'passed {}'.format(type(paths_or_selector))) + + self.init(GetResultValue(result)) + + cdef init(self, shared_ptr[CDatasetFactory]& sp): + DatasetFactory.init(self, sp) + self.filesystem_factory = sp.get() + + +cdef class UnionDatasetFactory(DatasetFactory): + """ + Provides a way to inspect/discover a Dataset's expected schema before + materialization. + + Parameters + ---------- + factories : list of DatasetFactory + """ + + cdef: + CUnionDatasetFactory* union_factory + + def __init__(self, list factories): + cdef: + DatasetFactory factory + vector[shared_ptr[CDatasetFactory]] c_factories + for factory in factories: + c_factories.push_back(factory.unwrap()) + self.init(GetResultValue(CUnionDatasetFactory.Make(c_factories))) + + cdef init(self, const shared_ptr[CDatasetFactory]& sp): + DatasetFactory.init(self, sp) + self.union_factory = sp.get() + + +cdef class RecordBatchIterator(_Weakrefable): + """An iterator over a sequence of record batches.""" + cdef: + # An object that must be kept alive with the iterator. 
+ object iterator_owner + # Iterator is a non-POD type and Cython uses offsetof, leading + # to a compiler warning unless wrapped like so + SharedPtrNoGIL[CRecordBatchIterator] iterator + + def __init__(self): + _forbid_instantiation(self.__class__, subclasses_instead=False) + + @staticmethod + cdef wrap(object owner, CRecordBatchIterator iterator): + cdef RecordBatchIterator self = \ + RecordBatchIterator.__new__(RecordBatchIterator) + self.iterator_owner = owner + self.iterator = make_shared[CRecordBatchIterator](move(iterator)) + return self + + cdef inline shared_ptr[CRecordBatchIterator] unwrap(self) nogil: + return self.iterator + + def __iter__(self): + return self + + def __next__(self): + cdef shared_ptr[CRecordBatch] record_batch + with nogil: + record_batch = GetResultValue(move(self.iterator.get().Next())) + if record_batch == NULL: + raise StopIteration + return pyarrow_wrap_batch(record_batch) + + +class TaggedRecordBatch(collections.namedtuple( + "TaggedRecordBatch", ["record_batch", "fragment"])): + """ + A combination of a record batch and the fragment it came from. + + Parameters + ---------- + record_batch : RecordBatch + The record batch. + fragment : Fragment + Fragment of the record batch. + """ + + +cdef class TaggedRecordBatchIterator(_Weakrefable): + """An iterator over a sequence of record batches with fragments.""" + cdef: + object iterator_owner + SharedPtrNoGIL[CTaggedRecordBatchIterator] iterator + + def __init__(self): + _forbid_instantiation(self.__class__, subclasses_instead=False) + + @staticmethod + cdef wrap(object owner, CTaggedRecordBatchIterator iterator): + cdef TaggedRecordBatchIterator self = \ + TaggedRecordBatchIterator.__new__(TaggedRecordBatchIterator) + self.iterator_owner = owner + self.iterator = make_shared[CTaggedRecordBatchIterator]( + move(iterator)) + return self + + def __iter__(self): + return self + + def __next__(self): + cdef CTaggedRecordBatch batch + with nogil: + batch = GetResultValue(move(self.iterator.get().Next())) + if batch.record_batch == NULL: + raise StopIteration + return TaggedRecordBatch( + record_batch=pyarrow_wrap_batch(batch.record_batch), + fragment=Fragment.wrap(batch.fragment)) + + +cdef void _populate_builder(const shared_ptr[CScannerBuilder]& ptr, + object columns=None, Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + bint use_threads=True, MemoryPool memory_pool=None, + FragmentScanOptions fragment_scan_options=None)\ + except *: + cdef: + CScannerBuilder *builder + vector[CExpression] c_exprs + + builder = ptr.get() + + check_status(builder.Filter(_bind( + filter, pyarrow_wrap_schema(builder.schema())))) + + if columns is not None: + if pa_substrait and isinstance(columns, pa_substrait.BoundExpressions): + columns = columns.expressions + + if isinstance(columns, dict): + for expr in columns.values(): + if not isinstance(expr, Expression): + raise TypeError( + "Expected an Expression for a 'column' dictionary " + "value, got {} instead".format(type(expr)) + ) + c_exprs.push_back(( expr).unwrap()) + + check_status( + builder.Project(c_exprs, [tobytes(c) for c in columns.keys()]) + ) + elif isinstance(columns, list): + check_status(builder.ProjectColumns([tobytes(c) for c in columns])) + else: + raise ValueError( + "Expected a list or a dict for 'columns', " + "got {} instead.".format(type(columns)) + ) + + check_status(builder.BatchSize(batch_size)) + 
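+ # The remaining settings map one-to-one onto ScannerBuilder setters;
+ # the readahead values trade additional memory for better IO overlap,
+ # and the memory pool / fragment scan options are forwarded unchanged.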
check_status(builder.BatchReadahead(batch_readahead)) + check_status(builder.FragmentReadahead(fragment_readahead)) + check_status(builder.UseThreads(use_threads)) + check_status(builder.Pool(maybe_unbox_memory_pool(memory_pool))) + if fragment_scan_options: + check_status( + builder.FragmentScanOptions(fragment_scan_options.wrapped)) + + +cdef class Scanner(_Weakrefable): + """A materialized scan operation with context and options bound. + + A scanner is the class that glues the scan tasks, data fragments and data + sources together. + """ + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef void init(self, const shared_ptr[CScanner]& sp): + self.wrapped = sp + self.scanner = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CScanner]& sp): + cdef Scanner self = Scanner.__new__(Scanner) + self.init(sp) + return self + + cdef inline shared_ptr[CScanner] unwrap(self): + return self.wrapped + + @staticmethod + cdef shared_ptr[CScanOptions] _make_scan_options(Dataset dataset, dict py_scanoptions) except *: + cdef: + shared_ptr[CScannerBuilder] builder = make_shared[CScannerBuilder](dataset.unwrap()) + + py_scanoptions = dataset._scanner_options(py_scanoptions) + + # Need to explicitly expand the arguments as Cython doesn't support + # keyword expansion in cdef functions. + _populate_builder( + builder, + columns=py_scanoptions.get("columns"), + filter=py_scanoptions.get("filter"), + batch_size=py_scanoptions.get("batch_size", _DEFAULT_BATCH_SIZE), + batch_readahead=py_scanoptions.get( + "batch_readahead", _DEFAULT_BATCH_READAHEAD), + fragment_readahead=py_scanoptions.get( + "fragment_readahead", _DEFAULT_FRAGMENT_READAHEAD), + use_threads=py_scanoptions.get("use_threads", True), + memory_pool=py_scanoptions.get("memory_pool"), + fragment_scan_options=py_scanoptions.get("fragment_scan_options")) + + return GetResultValue(deref(builder).GetScanOptions()) + + @staticmethod + def from_dataset(Dataset dataset not None, *, + object columns=None, + object filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, MemoryPool memory_pool=None): + """ + Create Scanner from Dataset, + + Parameters + ---------- + dataset : Dataset + Dataset to scan. + columns : list[str] or dict[str, Expression], default None + The columns to project. This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. + + The list of columns or expressions may use the special fields + `__batch_index` (the index of the batch within the fragment), + `__fragment_index` (the index of the fragment within the dataset), + `__last_in_fragment` (whether the batch is last in fragment), and + `__filename` (the name of the source file or a description of the + source fragment). + + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. 
+ If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + """ + cdef: + shared_ptr[CScanOptions] options + shared_ptr[CScannerBuilder] builder + shared_ptr[CScanner] scanner + + options = Scanner._make_scan_options( + dataset, + dict(columns=columns, filter=filter, batch_size=batch_size, + batch_readahead=batch_readahead, + fragment_readahead=fragment_readahead, use_threads=use_threads, + memory_pool=memory_pool, fragment_scan_options=fragment_scan_options) + ) + builder = make_shared[CScannerBuilder](dataset.unwrap(), options) + scanner = GetResultValue(builder.get().Finish()) + return Scanner.wrap(scanner) + + @staticmethod + def from_fragment(Fragment fragment not None, *, Schema schema=None, + object columns=None, Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, MemoryPool memory_pool=None): + """ + Create Scanner from Fragment, + + Parameters + ---------- + fragment : Fragment + fragment to scan. + schema : Schema, optional + The schema of the fragment. + columns : list[str] or dict[str, Expression], default None + The columns to project. This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. + + The list of columns or expressions may use the special fields + `__batch_index` (the index of the batch within the fragment), + `__fragment_index` (the index of the fragment within the dataset), + `__last_in_fragment` (whether the batch is last in fragment), and + `__filename` (the name of the source file or a description of the + source fragment). + + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. 
+ If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + """ + cdef: + shared_ptr[CScanOptions] options = make_shared[CScanOptions]() + shared_ptr[CScannerBuilder] builder + shared_ptr[CScanner] scanner + + schema = schema or fragment.physical_schema + + builder = make_shared[CScannerBuilder](pyarrow_unwrap_schema(schema), + fragment.unwrap(), options) + _populate_builder(builder, columns=columns, filter=filter, + batch_size=batch_size, batch_readahead=batch_readahead, + fragment_readahead=fragment_readahead, + use_threads=use_threads, + memory_pool=memory_pool, + fragment_scan_options=fragment_scan_options) + + scanner = GetResultValue(builder.get().Finish()) + return Scanner.wrap(scanner) + + @staticmethod + def from_batches(source, *, Schema schema=None, object columns=None, + Expression filter=None, int batch_size=_DEFAULT_BATCH_SIZE, + int batch_readahead=_DEFAULT_BATCH_READAHEAD, + int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD, + FragmentScanOptions fragment_scan_options=None, + bint use_threads=True, MemoryPool memory_pool=None): + """ + Create a Scanner from an iterator of batches. + + This creates a scanner which can be used only once. It is + intended to support writing a dataset (which takes a scanner) + from a source which can be read only once (e.g. a + RecordBatchReader or generator). + + Parameters + ---------- + source : Iterator or Arrow-compatible stream object + The iterator of Batches. This can be a pyarrow RecordBatchReader, + any object that implements the Arrow PyCapsule Protocol for + streams, or an actual Python iterator of RecordBatches. + schema : Schema + The schema of the batches (required when passing a Python + iterator). + columns : list[str] or dict[str, Expression], default None + The columns to project. This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. + + The list of columns or expressions may use the special fields + `__batch_index` (the index of the batch within the fragment), + `__fragment_index` (the index of the fragment within the dataset), + `__last_in_fragment` (whether the batch is last in fragment), and + `__filename` (the name of the source file or a description of the + source fragment). 
+ + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 131_072 + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + batch_readahead : int, default 16 + The number of batches to read ahead in a file. This might not work + for all file formats. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_readahead : int, default 4 + The number of files to read ahead. Increasing this number will increase + RAM usage but could also improve IO utilization. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + """ + cdef: + shared_ptr[CScannerBuilder] builder + shared_ptr[CScanner] scanner + RecordBatchReader reader + if isinstance(source, pa.ipc.RecordBatchReader): + if schema: + raise ValueError('Cannot specify a schema when providing ' + 'a RecordBatchReader') + reader = source + elif hasattr(source, "__arrow_c_stream__"): + if schema: + raise ValueError( + 'Cannot specify a schema when providing an object ' + 'implementing the Arrow PyCapsule Protocol') + reader = pa.ipc.RecordBatchReader.from_stream(source) + elif _is_iterable(source): + if schema is None: + raise ValueError('Must provide schema to construct scanner ' + 'from an iterable') + reader = pa.ipc.RecordBatchReader.from_batches(schema, source) + else: + raise TypeError('Expected a RecordBatchReader or an iterable of ' + 'batches instead of the given type: ' + + type(source).__name__) + builder = CScannerBuilder.FromRecordBatchReader(reader.reader) + _populate_builder(builder, columns=columns, filter=filter, + batch_size=batch_size, batch_readahead=batch_readahead, + fragment_readahead=fragment_readahead, use_threads=use_threads, + memory_pool=memory_pool, + fragment_scan_options=fragment_scan_options) + scanner = GetResultValue(builder.get().Finish()) + return Scanner.wrap(scanner) + + @property + def dataset_schema(self): + """The schema with which batches will be read from fragments.""" + return pyarrow_wrap_schema( + self.scanner.options().get().dataset_schema) + + @property + def projected_schema(self): + """ + The materialized schema of the data, accounting for projections. + + This is the schema of any data returned from the scanner. + """ + return pyarrow_wrap_schema( + self.scanner.options().get().projected_schema) + + def to_batches(self): + """ + Consume a Scanner in record batches. 
+ + Returns + ------- + record_batches : iterator of RecordBatch + """ + def _iterator(batch_iter): + for batch in batch_iter: + yield batch.record_batch + # Don't make ourselves a generator so errors are raised immediately + return _iterator(self.scan_batches()) + + def scan_batches(self): + """ + Consume a Scanner in record batches with corresponding fragments. + + Returns + ------- + record_batches : iterator of TaggedRecordBatch + """ + cdef CTaggedRecordBatchIterator iterator + with nogil: + iterator = move(GetResultValue(self.scanner.ScanBatches())) + # Don't make ourselves a generator so errors are raised immediately + return TaggedRecordBatchIterator.wrap(self, move(iterator)) + + def to_table(self): + """ + Convert a Scanner into a Table. + + Use this convenience utility with care. This will serially materialize + the Scan result in memory before creating the Table. + + Returns + ------- + Table + """ + cdef CResult[shared_ptr[CTable]] result + + with nogil: + result = self.scanner.ToTable() + + return pyarrow_wrap_table(GetResultValue(result)) + + def take(self, object indices): + """ + Select rows of data by index. + + Will only consume as many batches of the underlying dataset as + needed. Otherwise, this is equivalent to + ``to_table().take(indices)``. + + Parameters + ---------- + indices : Array or array-like + indices of rows to select in the dataset. + + Returns + ------- + Table + """ + cdef CResult[shared_ptr[CTable]] result + cdef shared_ptr[CArray] c_indices + + if not isinstance(indices, pa.Array): + indices = pa.array(indices) + c_indices = pyarrow_unwrap_array(indices) + + with nogil: + result = self.scanner.TakeRows(deref(c_indices)) + return pyarrow_wrap_table(GetResultValue(result)) + + def head(self, int num_rows): + """ + Load the first N rows of the dataset. + + Parameters + ---------- + num_rows : int + The number of rows to load. + + Returns + ------- + Table + """ + cdef CResult[shared_ptr[CTable]] result + with nogil: + result = self.scanner.Head(num_rows) + return pyarrow_wrap_table(GetResultValue(result)) + + def count_rows(self): + """ + Count rows matching the scanner filter. + + Returns + ------- + count : int + """ + cdef CResult[int64_t] result + with nogil: + result = self.scanner.CountRows() + return GetResultValue(result) + + def to_reader(self): + """Consume this scanner as a RecordBatchReader. + + Returns + ------- + RecordBatchReader + """ + cdef RecordBatchReader reader + reader = RecordBatchReader.__new__(RecordBatchReader) + reader.reader = GetResultValue(self.scanner.ToRecordBatchReader()) + return reader + + +def get_partition_keys(Expression partition_expression): + """ + Extract partition keys (equality constraints between a field and a scalar) + from an expression as a dict mapping the field's name to its value. + + NB: All expressions yielded by a HivePartitioning or DirectoryPartitioning + will be conjunctions of equality conditions and are accessible through this + function. Other subexpressions will be ignored. 
+ + Parameters + ---------- + partition_expression : pyarrow.dataset.Expression + + Returns + ------- + dict + + Examples + -------- + + For example, an expression of + + is converted to {'part': 'A', 'year': 2016} + """ + cdef: + CExpression expr = partition_expression.unwrap() + pair[CFieldRef, CDatum] ref_val + + out = {} + for ref_val in GetResultValue(CExtractKnownFieldValues(expr)).map: + assert ref_val.first.name() != nullptr + assert ref_val.second.kind() == DatumType_SCALAR + val = pyarrow_wrap_scalar(ref_val.second.scalar()) + out[frombytes(deref(ref_val.first.name()))] = val.as_py() + return out + + +cdef class WrittenFile(_Weakrefable): + """ + Metadata information about files written as + part of a dataset write operation + + Parameters + ---------- + path : str + Path to the file. + metadata : pyarrow.parquet.FileMetaData, optional + For Parquet files, the Parquet file metadata. + size : int + The size of the file in bytes. + """ + + def __init__(self, path, metadata, size): + self.path = path + self.metadata = metadata + self.size = size + + +cdef void _filesystemdataset_write_visitor( + dict visit_args, + CFileWriter* file_writer): + cdef: + str path + str base_dir + WrittenFile written_file + FileFormat file_format + + path = frombytes(deref(file_writer).destination().path) + base_dir = frombytes(visit_args['base_dir']) + file_format = FileFormat.wrap(file_writer.format()) + written_file = file_format._finish_write(path, base_dir, file_writer) + visit_args['file_visitor'](written_file) + + +def _filesystemdataset_write( + Scanner data not None, + object base_dir not None, + str basename_template not None, + FileSystem filesystem not None, + Partitioning partitioning not None, + FileWriteOptions file_options not None, + int max_partitions, + object file_visitor, + str existing_data_behavior not None, + int max_open_files, + int max_rows_per_file, + int min_rows_per_group, + int max_rows_per_group, + bool create_dir +): + """ + CFileSystemDataset.Write wrapper + """ + cdef: + CFileSystemDatasetWriteOptions c_options + shared_ptr[CScanner] c_scanner + dict visit_args + + c_options.file_write_options = file_options.unwrap() + c_options.filesystem = filesystem.unwrap() + c_options.base_dir = tobytes(_stringify_path(base_dir)) + c_options.partitioning = partitioning.unwrap() + c_options.max_partitions = max_partitions + c_options.max_open_files = max_open_files + c_options.max_rows_per_file = max_rows_per_file + c_options.max_rows_per_group = max_rows_per_group + c_options.min_rows_per_group = min_rows_per_group + c_options.basename_template = tobytes(basename_template) + if existing_data_behavior == 'error': + c_options.existing_data_behavior = ExistingDataBehavior_ERROR + elif existing_data_behavior == 'overwrite_or_ignore': + c_options.existing_data_behavior =\ + ExistingDataBehavior_OVERWRITE_OR_IGNORE + elif existing_data_behavior == 'delete_matching': + c_options.existing_data_behavior = ExistingDataBehavior_DELETE_MATCHING + else: + raise ValueError( + ("existing_data_behavior must be one of 'error', ", + "'overwrite_or_ignore' or 'delete_matching'") + ) + c_options.create_dir = create_dir + + if file_visitor is not None: + visit_args = {'base_dir': c_options.base_dir, + 'file_visitor': file_visitor} + # Need to use post_finish because parquet metadata is not available + # until after Finish has been called + c_options.writer_post_finish = BindFunction[cb_writer_finish_internal]( + &_filesystemdataset_write_visitor, visit_args) + + c_scanner = data.unwrap() + with nogil: 
+ check_status(CFileSystemDataset.Write(c_options, c_scanner)) + + +cdef class _ScanNodeOptions(ExecNodeOptions): + + def _set_options(self, Dataset dataset, dict scan_options): + cdef: + shared_ptr[CScanOptions] c_scan_options + bint require_sequenced_output=False + + c_scan_options = Scanner._make_scan_options(dataset, scan_options) + + require_sequenced_output=scan_options.get("require_sequenced_output", False) + + self.wrapped.reset( + new CScanNodeOptions(dataset.unwrap(), c_scan_options, require_sequenced_output) + ) + + +class ScanNodeOptions(_ScanNodeOptions): + """ + A Source node which yields batches from a Dataset scan. + + This is the option class for the "scan" node factory. + + This node is capable of applying pushdown projections or filters + to the file readers which reduce the amount of data that needs to + be read (if supported by the file format). But note that this does not + construct associated filter or project nodes to perform the final + filtering or projection. Rather, you may supply the same filter + expression or projection to the scan node that you also supply + to the filter or project node. + + Yielded batches will be augmented with fragment/batch indices to + enable stable ordering for simple ExecPlans. + + Parameters + ---------- + dataset : pyarrow.dataset.Dataset + The table which acts as the data source. + **kwargs : dict, optional + Scan options. See `Scanner.from_dataset` for possible arguments. + require_sequenced_output : bool, default False + Assert implicit ordering on data. + """ + + def __init__(self, Dataset dataset, **kwargs): + self._set_options(dataset, kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset_orc.cpython-312-x86_64-linux-gnu.so b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset_orc.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..207cf48a4e4c4d19bd955a5d1d385a1e08cc6691 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset_orc.cpython-312-x86_64-linux-gnu.so differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset_orc.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset_orc.pyx new file mode 100644 index 0000000000000000000000000000000000000000..a8cce3362225adcfd7e70b51e521f26d43d9a102 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset_orc.pyx @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
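For orientation, the scanner machinery and ScanNodeOptions wrapped above are normally driven through the high-level pyarrow.dataset API. A minimal sketch, assuming a hypothetical Parquet directory and column names (none of these values come from this patch):

import pyarrow.dataset as ds
import pyarrow.compute as pc

dataset = ds.dataset("data/example", format="parquet")  # hypothetical path

# Projection and filter are pushed down to the file readers where possible,
# mirroring the `columns=` and `filter=` options documented above.
scanner = dataset.scanner(
    columns=["id", "amount"],
    filter=pc.field("amount") > 0,
    batch_size=64_000,
)
table = scanner.to_table()        # materialize the whole result
# or: for batch in scanner.to_batches(): ...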
+ +# cython: language_level = 3 + +"""Dataset support for ORC file format.""" + +from pyarrow.lib cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_dataset cimport * + +from pyarrow._dataset cimport FileFormat + + +cdef class OrcFileFormat(FileFormat): + + def __init__(self): + self.init(shared_ptr[CFileFormat](new COrcFileFormat())) + + def equals(self, OrcFileFormat other): + """ + Parameters + ---------- + other : pyarrow.dataset.OrcFileFormat + + Returns + ------- + True + """ + return True + + @property + def default_extname(self): + return "orc" + + def __reduce__(self): + return OrcFileFormat, tuple() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset_parquet_encryption.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset_parquet_encryption.pyx new file mode 100644 index 0000000000000000000000000000000000000000..c8f5e5b01bf81f32d641d70341fe74bf6bfbbc80 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dataset_parquet_encryption.pyx @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +"""Dataset support for Parquet encryption.""" + +from pyarrow.includes.libarrow_dataset_parquet cimport * +from pyarrow._parquet_encryption cimport * +from pyarrow._dataset_parquet cimport ParquetFragmentScanOptions, ParquetFileWriteOptions + + +cdef class ParquetEncryptionConfig(_Weakrefable): + """ + Core configuration class encapsulating parameters for high-level encryption + within the Parquet framework. + + The ParquetEncryptionConfig class serves as a bridge for passing encryption-related + parameters to the appropriate components within the Parquet library. It maintains references + to objects that define the encryption strategy, Key Management Service (KMS) configuration, + and specific encryption configurations for Parquet data. + + Parameters + ---------- + crypto_factory : pyarrow.parquet.encryption.CryptoFactory + Shared pointer to a `CryptoFactory` object. The `CryptoFactory` is responsible for + creating cryptographic components, such as encryptors and decryptors. + kms_connection_config : pyarrow.parquet.encryption.KmsConnectionConfig + Shared pointer to a `KmsConnectionConfig` object. This object holds the configuration + parameters necessary for connecting to a Key Management Service (KMS). + encryption_config : pyarrow.parquet.encryption.EncryptionConfiguration + Shared pointer to an `EncryptionConfiguration` object. This object defines specific + encryption settings for Parquet data, including the keys assigned to different columns. + + Raises + ------ + ValueError + Raised if `encryption_config` is None. 
+ """ + cdef: + shared_ptr[CParquetEncryptionConfig] c_config + + # Avoid mistakenly creating attributes + __slots__ = () + + def __cinit__(self, CryptoFactory crypto_factory, KmsConnectionConfig kms_connection_config, + EncryptionConfiguration encryption_config): + + cdef shared_ptr[CEncryptionConfiguration] c_encryption_config + + if crypto_factory is None: + raise ValueError("crypto_factory cannot be None") + + if kms_connection_config is None: + raise ValueError("kms_connection_config cannot be None") + + if encryption_config is None: + raise ValueError("encryption_config cannot be None") + + self.c_config.reset(new CParquetEncryptionConfig()) + + c_encryption_config = pyarrow_unwrap_encryptionconfig( + encryption_config) + + self.c_config.get().crypto_factory = pyarrow_unwrap_cryptofactory(crypto_factory) + self.c_config.get().kms_connection_config = pyarrow_unwrap_kmsconnectionconfig( + kms_connection_config) + self.c_config.get().encryption_config = c_encryption_config + + @staticmethod + cdef wrap(shared_ptr[CParquetEncryptionConfig] c_config): + cdef ParquetEncryptionConfig python_config = ParquetEncryptionConfig.__new__(ParquetEncryptionConfig) + python_config.c_config = c_config + return python_config + + cdef shared_ptr[CParquetEncryptionConfig] unwrap(self): + return self.c_config + + +cdef class ParquetDecryptionConfig(_Weakrefable): + """ + Core configuration class encapsulating parameters for high-level decryption + within the Parquet framework. + + ParquetDecryptionConfig is designed to pass decryption-related parameters to + the appropriate decryption components within the Parquet library. It holds references to + objects that define the decryption strategy, Key Management Service (KMS) configuration, + and specific decryption configurations for reading encrypted Parquet data. + + Parameters + ---------- + crypto_factory : pyarrow.parquet.encryption.CryptoFactory + Shared pointer to a `CryptoFactory` object, pivotal in creating cryptographic + components for the decryption process. + kms_connection_config : pyarrow.parquet.encryption.KmsConnectionConfig + Shared pointer to a `KmsConnectionConfig` object, containing parameters necessary + for connecting to a Key Management Service (KMS) during decryption. + decryption_config : pyarrow.parquet.encryption.DecryptionConfiguration + Shared pointer to a `DecryptionConfiguration` object, specifying decryption settings + for reading encrypted Parquet data. + + Raises + ------ + ValueError + Raised if `decryption_config` is None. 
+ """ + + cdef: + shared_ptr[CParquetDecryptionConfig] c_config + + # Avoid mistakingly creating attributes + __slots__ = () + + def __cinit__(self, CryptoFactory crypto_factory, KmsConnectionConfig kms_connection_config, + DecryptionConfiguration decryption_config): + + cdef shared_ptr[CDecryptionConfiguration] c_decryption_config + + if decryption_config is None: + raise ValueError( + "decryption_config cannot be None") + + self.c_config.reset(new CParquetDecryptionConfig()) + + c_decryption_config = pyarrow_unwrap_decryptionconfig( + decryption_config) + + self.c_config.get().crypto_factory = pyarrow_unwrap_cryptofactory(crypto_factory) + self.c_config.get().kms_connection_config = pyarrow_unwrap_kmsconnectionconfig( + kms_connection_config) + self.c_config.get().decryption_config = c_decryption_config + + @staticmethod + cdef wrap(shared_ptr[CParquetDecryptionConfig] c_config): + cdef ParquetDecryptionConfig python_config = ParquetDecryptionConfig.__new__(ParquetDecryptionConfig) + python_config.c_config = c_config + return python_config + + cdef shared_ptr[CParquetDecryptionConfig] unwrap(self): + return self.c_config + + +def set_encryption_config( + ParquetFileWriteOptions opts not None, + ParquetEncryptionConfig config not None +): + cdef shared_ptr[CParquetEncryptionConfig] c_config = config.unwrap() + opts.parquet_options.parquet_encryption_config = c_config + + +def set_decryption_properties( + ParquetFragmentScanOptions opts not None, + FileDecryptionProperties config not None +): + cdef CReaderProperties* reader_props = opts.reader_properties() + reader_props.file_decryption_properties(config.unwrap()) + + +def set_decryption_config( + ParquetFragmentScanOptions opts not None, + ParquetDecryptionConfig config not None +): + cdef shared_ptr[CParquetDecryptionConfig] c_config = config.unwrap() + opts.parquet_options.parquet_decryption_config = c_config diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dlpack.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dlpack.pxi new file mode 100644 index 0000000000000000000000000000000000000000..c2f4cff64069195ad70f2ea271a842dfd166058c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_dlpack.pxi @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
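These helpers wire the encryption/decryption objects into dataset reads and writes. A rough usage sketch for the write side, assuming a user-supplied KMS client factory, hypothetical key and column names, and a build where ParquetEncryptionConfig is exposed through pyarrow.dataset:

import pyarrow.dataset as ds
import pyarrow.parquet.encryption as pe

# Placeholders: `my_kms_client_factory` and `table` are assumed to exist.
crypto_factory = pe.CryptoFactory(my_kms_client_factory)
kms_config = pe.KmsConnectionConfig()
enc_config = pe.EncryptionConfiguration(
    footer_key="footer_key",
    column_keys={"column_key": ["secret_column"]},
)

parquet_encryption_cfg = ds.ParquetEncryptionConfig(
    crypto_factory, kms_config, enc_config)
file_format = ds.ParquetFileFormat()
write_options = file_format.make_write_options(
    encryption_config=parquet_encryption_cfg)
ds.write_dataset(table, "encrypted_dataset",
                 format=file_format, file_options=write_options)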
+ +cimport cpython +from cpython.pycapsule cimport PyCapsule_New + + +cdef void dlpack_pycapsule_deleter(object dltensor) noexcept: + cdef DLManagedTensor* dlm_tensor + cdef PyObject* err_type + cdef PyObject* err_value + cdef PyObject* err_traceback + + # Do nothing if the capsule has been consumed + if cpython.PyCapsule_IsValid(dltensor, "used_dltensor"): + return + + # An exception may be in-flight, we must save it in case + # we create another one + cpython.PyErr_Fetch(&err_type, &err_value, &err_traceback) + + dlm_tensor = cpython.PyCapsule_GetPointer(dltensor, 'dltensor') + if dlm_tensor == NULL: + cpython.PyErr_WriteUnraisable(dltensor) + # The deleter can be NULL if there is no way for the caller + # to provide a reasonable destructor + elif dlm_tensor.deleter: + dlm_tensor.deleter(dlm_tensor) + assert (not cpython.PyErr_Occurred()) + + # Set the error indicator from err_type, err_value, err_traceback + cpython.PyErr_Restore(err_type, err_value, err_traceback) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_flight.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_flight.pyx new file mode 100644 index 0000000000000000000000000000000000000000..9a2341d6948e55aa745211df38c2147635e6b056 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_flight.pyx @@ -0,0 +1,3264 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
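The deleter above is attached to the 'dltensor' capsules pyarrow produces; once a consumer takes ownership it renames the capsule to 'used_dltensor', and the deleter becomes a no-op. A small sketch of that exchange, assuming a primitive, null-free array and a NumPy version that provides from_dlpack:

import numpy as np
import pyarrow as pa

arr = pa.array([1, 2, 3], type=pa.int64())
# Hands a 'dltensor' capsule to NumPy; NumPy renames it to
# 'used_dltensor', so the deleter above will not free it again.
np_view = np.from_dlpack(arr)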
+ +# cython: language_level = 3 + +import collections +import enum +import re +import time +import warnings +import weakref + +from cython.operator cimport dereference as deref +from cython.operator cimport postincrement +from libcpp cimport bool as c_bool + +from pyarrow.lib cimport * +from pyarrow.lib import (ArrowCancelled, ArrowException, ArrowInvalid, + SignalStopHandler) +from pyarrow.lib import as_buffer, frombytes, timestamp, tobytes +from pyarrow.includes.libarrow_flight cimport * +from pyarrow.ipc import _get_legacy_format_default, _ReadPandasMixin +import pyarrow.lib as lib + + +cdef CFlightCallOptions DEFAULT_CALL_OPTIONS + + +cdef int check_flight_status(const CStatus& status) except -1 nogil: + cdef shared_ptr[FlightStatusDetail] detail + + if status.ok(): + return 0 + + detail = FlightStatusDetail.UnwrapStatus(status) + if detail: + with gil: + message = frombytes(status.message(), safe=True) + detail_msg = detail.get().extra_info() + if detail.get().code() == CFlightStatusInternal: + raise FlightInternalError(message, detail_msg) + elif detail.get().code() == CFlightStatusFailed: + message = _munge_grpc_python_error(message) + raise FlightServerError(message, detail_msg) + elif detail.get().code() == CFlightStatusTimedOut: + raise FlightTimedOutError(message, detail_msg) + elif detail.get().code() == CFlightStatusCancelled: + raise FlightCancelledError(message, detail_msg) + elif detail.get().code() == CFlightStatusUnauthenticated: + raise FlightUnauthenticatedError(message, detail_msg) + elif detail.get().code() == CFlightStatusUnauthorized: + raise FlightUnauthorizedError(message, detail_msg) + elif detail.get().code() == CFlightStatusUnavailable: + raise FlightUnavailableError(message, detail_msg) + + size_detail = FlightWriteSizeStatusDetail.UnwrapStatus(status) + if size_detail: + with gil: + message = frombytes(status.message(), safe=True) + raise FlightWriteSizeExceededError( + message, + size_detail.get().limit(), size_detail.get().actual()) + + return check_status(status) + + +_FLIGHT_SERVER_ERROR_REGEX = re.compile( + r'Flight RPC failed with message: (.*). Detail: ' + r'Python exception: (.*)', + re.DOTALL +) + + +def _munge_grpc_python_error(message): + m = _FLIGHT_SERVER_ERROR_REGEX.match(message) + if m: + return ('Flight RPC failed with Python exception \"{}: {}\"' + .format(m.group(2), m.group(1))) + else: + return message + + +cdef IpcWriteOptions _get_options(options): + return _get_legacy_format_default( + use_legacy_format=None, options=options) + + +cdef class FlightCallOptions(_Weakrefable): + """RPC-layer options for a Flight call.""" + + cdef: + CFlightCallOptions options + + def __init__(self, timeout=None, write_options=None, headers=None, + IpcReadOptions read_options=None): + """Create call options. + + Parameters + ---------- + timeout : float, None + A timeout for the call, in seconds. None means that the + timeout defaults to an implementation-specific value. + write_options : pyarrow.ipc.IpcWriteOptions, optional + IPC write options. The default options can be controlled + by environment variables (see pyarrow.ipc). + headers : List[Tuple[str, str]], optional + A list of arbitrary headers as key, value tuples + read_options : pyarrow.ipc.IpcReadOptions, optional + Serialization options for reading IPC format. 
+ """ + cdef IpcWriteOptions c_write_options + + if timeout is not None: + self.options.timeout = CTimeoutDuration(timeout) + if write_options is not None: + c_write_options = _get_options(write_options) + self.options.write_options = c_write_options.c_options + if read_options is not None: + if not isinstance(read_options, IpcReadOptions): + raise TypeError("expected IpcReadOptions, got {}" + .format(type(read_options))) + self.options.read_options = read_options.c_options + if headers is not None: + self.options.headers = headers + + @staticmethod + cdef CFlightCallOptions* unwrap(obj): + if not obj: + return &DEFAULT_CALL_OPTIONS + elif isinstance(obj, FlightCallOptions): + return &(( obj).options) + raise TypeError("Expected a FlightCallOptions object, not " + "'{}'".format(type(obj))) + + +_CertKeyPair = collections.namedtuple('_CertKeyPair', ['cert', 'key']) + + +class CertKeyPair(_CertKeyPair): + """A TLS certificate and key for use in Flight.""" + + +cdef class FlightError(Exception): + """ + The base class for Flight-specific errors. + + A server may raise this class or one of its subclasses to provide + a more detailed error to clients. + + Parameters + ---------- + message : str, optional + The error message. + extra_info : bytes, optional + Extra binary error details that were provided by the + server/will be sent to the client. + + Attributes + ---------- + extra_info : bytes + Extra binary error details that were provided by the + server/will be sent to the client. + """ + + cdef dict __dict__ + + def __init__(self, message='', extra_info=b''): + super().__init__(message) + self.extra_info = tobytes(extra_info) + + cdef CStatus to_status(self): + message = tobytes("Flight error: {}".format(str(self))) + return CStatus_UnknownError(message) + + +cdef class FlightInternalError(FlightError, ArrowException): + """An error internal to the Flight server occurred.""" + + cdef CStatus to_status(self): + return MakeFlightError(CFlightStatusInternal, + tobytes(str(self)), self.extra_info) + + +cdef class FlightTimedOutError(FlightError, ArrowException): + """The Flight RPC call timed out.""" + + cdef CStatus to_status(self): + return MakeFlightError(CFlightStatusTimedOut, + tobytes(str(self)), self.extra_info) + + +cdef class FlightCancelledError(FlightError, ArrowCancelled): + """The operation was cancelled.""" + + cdef CStatus to_status(self): + return MakeFlightError(CFlightStatusCancelled, tobytes(str(self)), + self.extra_info) + + +cdef class FlightServerError(FlightError, ArrowException): + """A server error occurred.""" + + cdef CStatus to_status(self): + return MakeFlightError(CFlightStatusFailed, tobytes(str(self)), + self.extra_info) + + +cdef class FlightUnauthenticatedError(FlightError, ArrowException): + """The client is not authenticated.""" + + cdef CStatus to_status(self): + return MakeFlightError( + CFlightStatusUnauthenticated, tobytes(str(self)), self.extra_info) + + +cdef class FlightUnauthorizedError(FlightError, ArrowException): + """The client is not authorized to perform the given operation.""" + + cdef CStatus to_status(self): + return MakeFlightError(CFlightStatusUnauthorized, tobytes(str(self)), + self.extra_info) + + +cdef class FlightUnavailableError(FlightError, ArrowException): + """The server is not reachable or available.""" + + cdef CStatus to_status(self): + return MakeFlightError(CFlightStatusUnavailable, tobytes(str(self)), + self.extra_info) + + +class FlightWriteSizeExceededError(ArrowInvalid): + """A write operation exceeded the client-configured 
limit.""" + + def __init__(self, message, limit, actual): + super().__init__(message) + self.limit = limit + self.actual = actual + + +cdef class Action(_Weakrefable): + """An action executable on a Flight service.""" + cdef: + CAction action + + def __init__(self, action_type, buf): + """Create an action from a type and a buffer. + + Parameters + ---------- + action_type : bytes or str + buf : Buffer or bytes-like object + """ + self.action.type = tobytes(action_type) + self.action.body = pyarrow_unwrap_buffer(as_buffer(buf)) + + @property + def type(self): + """The action type.""" + return frombytes(self.action.type) + + @property + def body(self): + """The action body (arguments for the action).""" + return pyarrow_wrap_buffer(self.action.body) + + @staticmethod + cdef CAction unwrap(action) except *: + if not isinstance(action, Action): + raise TypeError("Must provide Action, not '{}'".format( + type(action))) + return ( action).action + + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.action.SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef Action action = Action.__new__(Action) + action.action = GetResultValue( + CAction.Deserialize(tobytes(serialized))) + return action + + def __eq__(self, Action other): + return self.action == other.action + + def __repr__(self): + return (f"") + + +_ActionType = collections.namedtuple('_ActionType', ['type', 'description']) + + +class ActionType(_ActionType): + """A type of action that is executable on a Flight service.""" + + def make_action(self, buf): + """Create an Action with this type. + + Parameters + ---------- + buf : obj + An Arrow buffer or Python bytes or bytes-like object. + """ + return Action(self.type, buf) + + +cdef class Result(_Weakrefable): + """A result from executing an Action.""" + cdef: + unique_ptr[CFlightResult] result + + def __init__(self, buf): + """Create a new result. + + Parameters + ---------- + buf : Buffer or bytes-like object + """ + self.result.reset(new CFlightResult()) + self.result.get().body = pyarrow_unwrap_buffer(as_buffer(buf)) + + @property + def body(self): + """Get the Buffer containing the result.""" + return pyarrow_wrap_buffer(self.result.get().body) + + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.result.get().SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. 
+ + """ + cdef Result result = Result.__new__(Result) + result.result.reset(new CFlightResult(GetResultValue( + CFlightResult.Deserialize(tobytes(serialized))))) + return result + + def __eq__(self, Result other): + return deref(self.result.get()) == deref(other.result.get()) + + def __repr__(self): + return f"" + + +cdef class BasicAuth(_Weakrefable): + """A container for basic auth.""" + cdef: + unique_ptr[CBasicAuth] basic_auth + + def __init__(self, username=None, password=None): + """Create a new basic auth object. + + Parameters + ---------- + username : string + password : string + """ + self.basic_auth.reset(new CBasicAuth()) + if username: + self.basic_auth.get().username = tobytes(username) + if password: + self.basic_auth.get().password = tobytes(password) + + @property + def username(self): + """Get the username.""" + return self.basic_auth.get().username + + @property + def password(self): + """Get the password.""" + return self.basic_auth.get().password + + @staticmethod + def deserialize(serialized): + auth = BasicAuth() + auth.basic_auth.reset(new CBasicAuth(GetResultValue( + CBasicAuth.Deserialize(tobytes(serialized))))) + return auth + + def serialize(self): + return GetResultValue(self.basic_auth.get().SerializeToString()) + + def __eq__(self, BasicAuth other): + return deref(self.basic_auth.get()) == deref(other.basic_auth.get()) + + def __repr__(self): + return (f"") + + +class DescriptorType(enum.Enum): + """ + The type of a FlightDescriptor. + + Attributes + ---------- + + UNKNOWN + An unknown descriptor type. + + PATH + A Flight stream represented by a path. + + CMD + A Flight stream represented by an application-defined command. + + """ + + UNKNOWN = 0 + PATH = 1 + CMD = 2 + + +class FlightMethod(enum.Enum): + """The implemented methods in Flight.""" + + INVALID = 0 + HANDSHAKE = 1 + LIST_FLIGHTS = 2 + GET_FLIGHT_INFO = 3 + GET_SCHEMA = 4 + DO_GET = 5 + DO_PUT = 6 + DO_ACTION = 7 + LIST_ACTIONS = 8 + DO_EXCHANGE = 9 + + +cdef wrap_flight_method(CFlightMethod method): + if method == CFlightMethodHandshake: + return FlightMethod.HANDSHAKE + elif method == CFlightMethodListFlights: + return FlightMethod.LIST_FLIGHTS + elif method == CFlightMethodGetFlightInfo: + return FlightMethod.GET_FLIGHT_INFO + elif method == CFlightMethodGetSchema: + return FlightMethod.GET_SCHEMA + elif method == CFlightMethodDoGet: + return FlightMethod.DO_GET + elif method == CFlightMethodDoPut: + return FlightMethod.DO_PUT + elif method == CFlightMethodDoAction: + return FlightMethod.DO_ACTION + elif method == CFlightMethodListActions: + return FlightMethod.LIST_ACTIONS + elif method == CFlightMethodDoExchange: + return FlightMethod.DO_EXCHANGE + return FlightMethod.INVALID + + +cdef class FlightDescriptor(_Weakrefable): + """A description of a data stream available from a Flight service.""" + cdef: + CFlightDescriptor descriptor + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use " + "`pyarrow.flight.FlightDescriptor.for_{path,command}` " + "function instead." 
+ .format(self.__class__.__name__)) + + @staticmethod + def for_path(*path): + """Create a FlightDescriptor for a resource path.""" + cdef FlightDescriptor result = \ + FlightDescriptor.__new__(FlightDescriptor) + result.descriptor.type = CDescriptorTypePath + result.descriptor.path = [tobytes(p) for p in path] + return result + + @staticmethod + def for_command(command): + """Create a FlightDescriptor for an opaque command.""" + cdef FlightDescriptor result = \ + FlightDescriptor.__new__(FlightDescriptor) + result.descriptor.type = CDescriptorTypeCmd + result.descriptor.cmd = tobytes(command) + return result + + @property + def descriptor_type(self): + """Get the type of this descriptor.""" + if self.descriptor.type == CDescriptorTypeUnknown: + return DescriptorType.UNKNOWN + elif self.descriptor.type == CDescriptorTypePath: + return DescriptorType.PATH + elif self.descriptor.type == CDescriptorTypeCmd: + return DescriptorType.CMD + raise RuntimeError("Invalid descriptor type!") + + @property + def command(self): + """Get the command for this descriptor.""" + if self.descriptor_type != DescriptorType.CMD: + return None + return self.descriptor.cmd + + @property + def path(self): + """Get the path for this descriptor.""" + if self.descriptor_type != DescriptorType.PATH: + return None + return self.descriptor.path + + def __repr__(self): + if self.descriptor_type == DescriptorType.PATH: + return f"" + elif self.descriptor_type == DescriptorType.CMD: + return f"" + else: + return "" + + @staticmethod + cdef CFlightDescriptor unwrap(descriptor) except *: + if not isinstance(descriptor, FlightDescriptor): + raise TypeError("Must provide a FlightDescriptor, not '{}'".format( + type(descriptor))) + return ( descriptor).descriptor + + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.descriptor.SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef FlightDescriptor descriptor = \ + FlightDescriptor.__new__(FlightDescriptor) + descriptor.descriptor = GetResultValue( + CFlightDescriptor.Deserialize(tobytes(serialized))) + return descriptor + + def __eq__(self, FlightDescriptor other): + return self.descriptor == other.descriptor + + +cdef class Ticket(_Weakrefable): + """A ticket for requesting a Flight stream.""" + + cdef: + CTicket c_ticket + + def __init__(self, ticket): + self.c_ticket.ticket = tobytes(ticket) + + @property + def ticket(self): + return self.c_ticket.ticket + + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.c_ticket.SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. 
+ + """ + cdef Ticket ticket = Ticket.__new__(Ticket) + ticket.c_ticket = GetResultValue( + CTicket.Deserialize(tobytes(serialized))) + return ticket + + def __eq__(self, Ticket other): + return self.c_ticket == other.c_ticket + + def __repr__(self): + return f"" + + +cdef class Location(_Weakrefable): + """The location of a Flight service.""" + cdef: + CLocation location + + def __init__(self, uri): + check_flight_status(CLocation.Parse(tobytes(uri)).Value(&self.location)) + + def __repr__(self): + return f'' + + @property + def uri(self): + return self.location.ToString() + + def equals(self, Location other): + return self == other + + def __eq__(self, other): + if not isinstance(other, Location): + return NotImplemented + return self.location.Equals(( other).location) + + @staticmethod + def for_grpc_tcp(host, port): + """Create a Location for a TCP-based gRPC service.""" + cdef: + c_string c_host = tobytes(host) + int c_port = port + Location result = Location.__new__(Location) + check_flight_status( + CLocation.ForGrpcTcp(c_host, c_port).Value(&result.location)) + return result + + @staticmethod + def for_grpc_tls(host, port): + """Create a Location for a TLS-based gRPC service.""" + cdef: + c_string c_host = tobytes(host) + int c_port = port + Location result = Location.__new__(Location) + check_flight_status( + CLocation.ForGrpcTls(c_host, c_port).Value(&result.location)) + return result + + @staticmethod + def for_grpc_unix(path): + """Create a Location for a domain socket-based gRPC service.""" + cdef: + c_string c_path = tobytes(path) + Location result = Location.__new__(Location) + check_flight_status(CLocation.ForGrpcUnix(c_path).Value(&result.location)) + return result + + @staticmethod + cdef Location wrap(CLocation location): + cdef Location result = Location.__new__(Location) + result.location = location + return result + + @staticmethod + cdef CLocation unwrap(object location) except *: + cdef CLocation c_location + if isinstance(location, str): + check_flight_status( + CLocation.Parse(tobytes(location)).Value(&c_location)) + return c_location + elif not isinstance(location, Location): + raise TypeError("Must provide a Location, not '{}'".format( + type(location))) + return ( location).location + + +cdef class FlightEndpoint(_Weakrefable): + """A Flight stream, along with the ticket and locations to access it.""" + cdef: + CFlightEndpoint endpoint + + def __init__(self, ticket, locations, expiration_time=None, app_metadata=""): + """Create a FlightEndpoint from a ticket and list of locations. + + Parameters + ---------- + ticket : Ticket or bytes + the ticket needed to access this flight + locations : list of string URIs + locations where this flight is available + expiration_time : TimestampScalar, default None + Expiration time of this stream. If present, clients may assume + they can retry DoGet requests. Otherwise, clients should avoid + retrying DoGet requests. + app_metadata : bytes or str, default "" + Application-defined opaque metadata. + + Raises + ------ + ArrowException + If one of the location URIs is not a valid URI. 
+ """ + cdef: + CLocation c_location + + if isinstance(ticket, Ticket): + self.endpoint.ticket.ticket = tobytes(ticket.ticket) + elif isinstance(ticket, (str, bytes)): + self.endpoint.ticket.ticket = tobytes(ticket) + else: + raise TypeError("Argument ticket must be a Ticket instance, string or bytes, " + "not '{}'".format(type(ticket))) + + for location in locations: + if isinstance(location, Location): + c_location = ( location).location + elif isinstance(location, (str, bytes)): + c_location = CLocation() + check_flight_status( + CLocation.Parse(tobytes(location)).Value(&c_location)) + else: + raise TypeError("Argument locations must contain Location instances, strings or bytes, " + "not '{}'".format(type(location))) + self.endpoint.locations.push_back(c_location) + + if expiration_time is not None: + if isinstance(expiration_time, lib.TimestampScalar): + self.endpoint.expiration_time = TimePoint_from_ns( + expiration_time.cast(timestamp("ns")).value) + else: + raise TypeError("Argument expiration_time must be a TimestampScalar, " + "not '{}'".format(type(expiration_time))) + + if not isinstance(app_metadata, (str, bytes)): + raise TypeError("Argument app_metadata must be a string or bytes, " + "not '{}'".format(type(app_metadata))) + self.endpoint.app_metadata = tobytes(app_metadata) + + @property + def ticket(self): + """Get the ticket in this endpoint.""" + return Ticket(self.endpoint.ticket.ticket) + + @property + def locations(self): + """Get locations where this flight is available.""" + return [Location.wrap(location) + for location in self.endpoint.locations] + + @property + def expiration_time(self): + """Get the expiration time of this stream. + + If present, clients may assume they can retry DoGet requests. + Otherwise, clients should avoid retrying DoGet requests. + + """ + cdef: + int64_t time_since_epoch + if self.endpoint.expiration_time.has_value(): + time_since_epoch = TimePoint_to_ns(self.endpoint.expiration_time.value()) + return lib.scalar(time_since_epoch, timestamp("ns", "UTC")) + return None + + @property + def app_metadata(self): + """Get application-defined opaque metadata.""" + return self.endpoint.app_metadata + + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.endpoint.SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef FlightEndpoint endpoint = FlightEndpoint.__new__(FlightEndpoint) + endpoint.endpoint = GetResultValue( + CFlightEndpoint.Deserialize(tobytes(serialized))) + return endpoint + + def __repr__(self): + return (f"") + + def __eq__(self, FlightEndpoint other): + return self.endpoint == other.endpoint + + +cdef class SchemaResult(_Weakrefable): + """The serialized schema returned from a GetSchema request.""" + cdef: + unique_ptr[CSchemaResult] result + + def __init__(self, Schema schema): + """Create a SchemaResult from a schema. + + Parameters + ---------- + schema: Schema + the schema of the data in this flight. 
+ """ + cdef: + shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema) + check_flight_status(CreateSchemaResult(c_schema, &self.result)) + + @property + def schema(self): + """The schema of the data in this flight.""" + cdef: + shared_ptr[CSchema] schema + CDictionaryMemo dummy_memo + + check_flight_status(self.result.get().GetSchema(&dummy_memo).Value(&schema)) + return pyarrow_wrap_schema(schema) + + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.result.get().SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef SchemaResult result = SchemaResult.__new__(SchemaResult) + result.result.reset(new CSchemaResult(GetResultValue( + CSchemaResult.Deserialize(tobytes(serialized))))) + return result + + def __eq__(self, SchemaResult other): + return deref(self.result.get()) == deref(other.result.get()) + + def __repr__(self): + return f"" + + +cdef class FlightInfo(_Weakrefable): + """A description of a Flight stream.""" + cdef: + unique_ptr[CFlightInfo] info + + @staticmethod + cdef wrap(CFlightInfo c_info): + cdef FlightInfo obj = FlightInfo.__new__(FlightInfo) + obj.info.reset(new CFlightInfo(move(c_info))) + return obj + + def __init__(self, Schema schema, FlightDescriptor descriptor, endpoints, + total_records=None, total_bytes=None, ordered=False, app_metadata=""): + """Create a FlightInfo object from a schema, descriptor, and endpoints. + + Parameters + ---------- + schema : Schema + the schema of the data in this flight. + descriptor : FlightDescriptor + the descriptor for this flight. + endpoints : list of FlightEndpoint + a list of endpoints where this flight is available. + total_records : int, default None + the total records in this flight, -1 or None if unknown. + total_bytes : int, default None + the total bytes in this flight, -1 or None if unknown. + ordered : boolean, default False + Whether endpoints are in the same order as the data. + app_metadata : bytes or str, default "" + Application-defined opaque metadata. + """ + cdef: + shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema) + vector[CFlightEndpoint] c_endpoints + + for endpoint in endpoints: + if isinstance(endpoint, FlightEndpoint): + c_endpoints.push_back(( endpoint).endpoint) + else: + raise TypeError('Endpoint {} is not instance of' + ' FlightEndpoint'.format(endpoint)) + + check_flight_status(CreateFlightInfo(c_schema, + descriptor.descriptor, + c_endpoints, + total_records if total_records is not None else -1, + total_bytes if total_bytes is not None else -1, + ordered, + tobytes(app_metadata), &self.info)) + + @property + def total_records(self): + """The total record count of this flight, or -1 if unknown.""" + return self.info.get().total_records() + + @property + def total_bytes(self): + """The size in bytes of the data in this flight, or -1 if unknown.""" + return self.info.get().total_bytes() + + @property + def ordered(self): + """Whether endpoints are in the same order as the data.""" + return self.info.get().ordered() + + @property + def app_metadata(self): + """ + Application-defined opaque metadata. 
+ + There is no inherent or required relationship between this and the + app_metadata fields in the FlightEndpoints or resulting FlightData + messages. Since this metadata is application-defined, a given + application could define there to be a relationship, but there is + none required by the spec. + + """ + return self.info.get().app_metadata() + + @property + def schema(self): + """The schema of the data in this flight.""" + cdef: + shared_ptr[CSchema] schema + CDictionaryMemo dummy_memo + + check_flight_status(self.info.get().GetSchema(&dummy_memo).Value(&schema)) + return pyarrow_wrap_schema(schema) + + @property + def descriptor(self): + """The descriptor of the data in this flight.""" + cdef FlightDescriptor result = \ + FlightDescriptor.__new__(FlightDescriptor) + result.descriptor = self.info.get().descriptor() + return result + + @property + def endpoints(self): + """The endpoints where this flight is available.""" + # TODO: get Cython to iterate over reference directly + cdef: + vector[CFlightEndpoint] endpoints = self.info.get().endpoints() + FlightEndpoint py_endpoint + + result = [] + for endpoint in endpoints: + py_endpoint = FlightEndpoint.__new__(FlightEndpoint) + py_endpoint.endpoint = endpoint + result.append(py_endpoint) + return result + + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.info.get().SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef FlightInfo info = FlightInfo.__new__(FlightInfo) + info.info = move(GetResultValue( + CFlightInfo.Deserialize(tobytes(serialized)))) + return info + + def __eq__(self, FlightInfo other): + return deref(self.info.get()) == deref(other.info.get()) + + def __repr__(self): + return (f"") + + +cdef class FlightStreamChunk(_Weakrefable): + """A RecordBatch with application metadata on the side.""" + cdef: + CFlightStreamChunk chunk + + @property + def data(self): + if self.chunk.data == NULL: + return None + return pyarrow_wrap_batch(self.chunk.data) + + @property + def app_metadata(self): + if self.chunk.app_metadata == NULL: + return None + return pyarrow_wrap_buffer(self.chunk.app_metadata) + + def __iter__(self): + return iter((self.data, self.app_metadata)) + + def __repr__(self): + return "".format( + self.chunk.data != NULL, self.chunk.app_metadata != NULL) + + +cdef class _MetadataRecordBatchReader(_Weakrefable, _ReadPandasMixin): + """A reader for Flight streams.""" + + # Needs to be separate class so the "real" class can subclass the + # pure-Python mixin class + + cdef dict __dict__ + cdef shared_ptr[CMetadataRecordBatchReader] reader + + def __iter__(self): + return self + + def __next__(self): + return self.read_chunk() + + @property + def schema(self): + """Get the schema for this reader.""" + cdef shared_ptr[CSchema] c_schema + with nogil: + check_flight_status(self.reader.get().GetSchema().Value(&c_schema)) + return pyarrow_wrap_schema(c_schema) + + def read_all(self): + """Read the entire contents of the stream as a Table.""" + cdef: + shared_ptr[CTable] c_table + with nogil: + check_flight_status(self.reader.get().ToTable().Value(&c_table)) + return pyarrow_wrap_table(c_table) + + def read_chunk(self): + """Read the next 
FlightStreamChunk along with any metadata. + + Returns + ------- + chunk : FlightStreamChunk + The next FlightStreamChunk in the stream. + + Raises + ------ + StopIteration + when the stream is finished + """ + cdef: + FlightStreamChunk chunk = FlightStreamChunk() + + with nogil: + check_flight_status(self.reader.get().Next().Value(&chunk.chunk)) + + if chunk.chunk.data == NULL and chunk.chunk.app_metadata == NULL: + raise StopIteration + + return chunk + + def to_reader(self): + """Convert this reader into a regular RecordBatchReader. + + This may fail if the schema cannot be read from the remote end. + + Returns + ------- + RecordBatchReader + """ + cdef RecordBatchReader reader + reader = RecordBatchReader.__new__(RecordBatchReader) + with nogil: + reader.reader = GetResultValue(MakeRecordBatchReader(self.reader)) + + return reader + + +cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader): + """The base class for readers for Flight streams. + + See Also + -------- + FlightStreamReader + """ + + +cdef class FlightStreamReader(MetadataRecordBatchReader): + """A reader that can also be canceled.""" + + def cancel(self): + """Cancel the read operation.""" + with nogil: + ( self.reader.get()).Cancel() + + def read_all(self): + """Read the entire contents of the stream as a Table.""" + cdef: + shared_ptr[CTable] c_table + CStopToken stop_token + with SignalStopHandler() as stop_handler: + stop_token = ( stop_handler.stop_token).stop_token + with nogil: + check_flight_status( + ( self.reader.get()) + .ToTableWithStopToken(stop_token).Value(&c_table)) + return pyarrow_wrap_table(c_table) + + +cdef class MetadataRecordBatchWriter(_CRecordBatchWriter): + """A RecordBatchWriter that also allows writing application metadata. + + This class is a context manager; on exit, close() will be called. + """ + + cdef CMetadataRecordBatchWriter* _writer(self) nogil: + return self.writer.get() + + def begin(self, schema: Schema, options=None): + """Prepare to write data to this stream with the given schema.""" + cdef: + shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema) + CIpcWriteOptions c_options = _get_options(options).c_options + with nogil: + check_flight_status(self._writer().Begin(c_schema, c_options)) + + def write_metadata(self, buf): + """Write Flight metadata by itself.""" + cdef shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(as_buffer(buf)) + with nogil: + check_flight_status( + self._writer().WriteMetadata(c_buf)) + + def write_batch(self, RecordBatch batch): + """ + Write RecordBatch to stream. + + Parameters + ---------- + batch : RecordBatch + """ + cdef: + shared_ptr[const CKeyValueMetadata] custom_metadata + + # Override superclass method to use check_flight_status so we + # can generate FlightWriteSizeExceededError. We don't do this + # for write_table as callers who intend to handle the error + # and retry with a smaller batch should be working with + # individual batches to have control. + + with nogil: + check_flight_status( + self._writer().WriteRecordBatch(deref(batch.batch), custom_metadata)) + + def write_table(self, Table table, max_chunksize=None, **kwargs): + """ + Write Table to stream in (contiguous) RecordBatch objects. + + Parameters + ---------- + table : Table + max_chunksize : int, default None + Maximum number of rows for RecordBatch chunks. Individual chunks may + be smaller depending on the chunk layout of individual columns. 
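+
+        Examples
+        --------
+        A rough client-side sketch of uploading a table over DoPut; the
+        URI and descriptor path are placeholders::
+
+            import pyarrow as pa
+            import pyarrow.flight as flight
+
+            table = pa.table({"x": [1, 2, 3]})
+            client = flight.connect("grpc://localhost:8815")  # placeholder
+            writer, _ = client.do_put(
+                flight.FlightDescriptor.for_path("example"), table.schema)
+            writer.write_table(table)
+            writer.close()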
+ """ + cdef: + # max_chunksize must be > 0 to have any impact + int64_t c_max_chunksize = -1 + + if 'chunksize' in kwargs: + max_chunksize = kwargs['chunksize'] + msg = ('The parameter chunksize is deprecated for the write_table ' + 'methods as of 0.15, please use parameter ' + 'max_chunksize instead') + warnings.warn(msg, FutureWarning) + + if max_chunksize is not None: + c_max_chunksize = max_chunksize + + with nogil: + check_flight_status( + self._writer().WriteTable(table.table[0], c_max_chunksize)) + + def close(self): + """ + Close stream and write end-of-stream 0 marker. + """ + with nogil: + check_flight_status(self._writer().Close()) + + def write_with_metadata(self, RecordBatch batch, buf): + """Write a RecordBatch along with Flight metadata. + + Parameters + ---------- + batch : RecordBatch + The next RecordBatch in the stream. + buf : Buffer + Application-specific metadata for the batch as defined by + Flight. + """ + cdef shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(as_buffer(buf)) + with nogil: + check_flight_status( + self._writer().WriteWithMetadata(deref(batch.batch), c_buf)) + + +cdef class FlightStreamWriter(MetadataRecordBatchWriter): + """A writer that also allows closing the write side of a stream.""" + + def done_writing(self): + """Indicate that the client is done writing, but not done reading.""" + with nogil: + check_flight_status( + ( self.writer.get()).DoneWriting()) + + +cdef class FlightMetadataReader(_Weakrefable): + """A reader for Flight metadata messages sent during a DoPut.""" + + cdef: + unique_ptr[CFlightMetadataReader] reader + + def read(self): + """Read the next metadata message.""" + cdef shared_ptr[CBuffer] buf + with nogil: + check_flight_status(self.reader.get().ReadMetadata(&buf)) + if buf == NULL: + return None + return pyarrow_wrap_buffer(buf) + + +cdef class FlightMetadataWriter(_Weakrefable): + """A sender for Flight metadata messages during a DoPut.""" + + cdef: + unique_ptr[CFlightMetadataWriter] writer + + def write(self, message): + """Write the next metadata message. + + Parameters + ---------- + message : Buffer + """ + cdef shared_ptr[CBuffer] buf = \ + pyarrow_unwrap_buffer(as_buffer(message)) + with nogil: + check_flight_status(self.writer.get().WriteMetadata(deref(buf))) + + +class AsyncioCall: + """State for an async RPC using asyncio.""" + + def __init__(self) -> None: + import asyncio + self._future = asyncio.get_running_loop().create_future() + + def as_awaitable(self) -> object: + return self._future + + def wakeup(self, result_or_exception) -> None: + # Mark the Future done from within its loop (asyncio + # objects are generally not thread-safe) + loop = self._future.get_loop() + if isinstance(result_or_exception, BaseException): + loop.call_soon_threadsafe( + self._future.set_exception, result_or_exception) + else: + loop.call_soon_threadsafe( + self._future.set_result, result_or_exception) + + +cdef class AsyncioFlightClient: + """ + A FlightClient with an asyncio-based async interface. + + This interface is EXPERIMENTAL. 
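+
+    Examples
+    --------
+    A minimal asyncio sketch; the URI and command are placeholders, and
+    the build/transport must support the async API::
+
+        import asyncio
+        import pyarrow.flight as flight
+
+        async def fetch_info():
+            client = flight.connect("grpc://localhost:8815")  # placeholder
+            async_client = client.as_async()
+            descriptor = flight.FlightDescriptor.for_command(b"example")
+            return await async_client.get_flight_info(descriptor)
+
+        info = asyncio.run(fetch_info())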
+ """ + + cdef: + FlightClient _client + + def __init__(self, FlightClient client) -> None: + self._client = client + + async def get_flight_info( + self, + descriptor: FlightDescriptor, + *, + options: FlightCallOptions = None, + ): + call = AsyncioCall() + self._get_flight_info(call, descriptor, options) + return await call.as_awaitable() + + cdef _get_flight_info(self, call, descriptor, options): + cdef: + CFlightCallOptions* c_options = \ + FlightCallOptions.unwrap(options) + CFlightDescriptor c_descriptor = \ + FlightDescriptor.unwrap(descriptor) + CFuture[CFlightInfo] c_future + + with nogil: + c_future = self._client.client.get().GetFlightInfoAsync( + deref(c_options), c_descriptor) + + BindFuture(move(c_future), call.wakeup, FlightInfo.wrap) + + +cdef class FlightClient(_Weakrefable): + """A client to a Flight service. + + Connect to a Flight service on the given host and port. + + Parameters + ---------- + location : str, tuple or Location + Location to connect to. Either a gRPC URI like `grpc://localhost:port`, + a tuple of (host, port) pair, or a Location instance. + tls_root_certs : bytes or None + PEM-encoded + cert_chain: bytes or None + Client certificate if using mutual TLS + private_key: bytes or None + Client private key for cert_chain is using mutual TLS + override_hostname : str or None + Override the hostname checked by TLS. Insecure, use with caution. + middleware : list optional, default None + A list of ClientMiddlewareFactory instances. + write_size_limit_bytes : int optional, default None + A soft limit on the size of a data payload sent to the + server. Enabled if positive. If enabled, writing a record + batch that (when serialized) exceeds this limit will raise an + exception; the client can retry the write with a smaller + batch. + disable_server_verification : boolean optional, default False + A flag that indicates that, if the client is connecting + with TLS, that it skips server verification. If this is + enabled, all other TLS settings are overridden. + generic_options : list optional, default None + A list of generic (string, int or string) option tuples passed + to the underlying transport. Effect is implementation + dependent. 
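+
+    Examples
+    --------
+    Connecting and listing the available flights; the location below is
+    an illustrative placeholder::
+
+        import pyarrow.flight as flight
+
+        client = flight.FlightClient("grpc+tcp://localhost:8815")
+        for info in client.list_flights():
+            print(info.descriptor)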
+ """ + cdef: + unique_ptr[CFlightClient] client + + def __init__(self, location, *, tls_root_certs=None, cert_chain=None, + private_key=None, override_hostname=None, middleware=None, + write_size_limit_bytes=None, + disable_server_verification=None, generic_options=None): + if isinstance(location, (bytes, str)): + location = Location(location) + elif isinstance(location, tuple): + host, port = location + if tls_root_certs or disable_server_verification is not None: + location = Location.for_grpc_tls(host, port) + else: + location = Location.for_grpc_tcp(host, port) + elif not isinstance(location, Location): + raise TypeError('`location` argument must be a string, tuple or a ' + 'Location instance') + self.init(location, tls_root_certs, cert_chain, private_key, + override_hostname, middleware, write_size_limit_bytes, + disable_server_verification, generic_options) + + cdef init(self, Location location, tls_root_certs, cert_chain, + private_key, override_hostname, middleware, + write_size_limit_bytes, disable_server_verification, + generic_options): + cdef: + CLocation c_location = Location.unwrap(location) + CFlightClientOptions c_options = CFlightClientOptions.Defaults() + function[cb_client_middleware_start_call] start_call = \ + &_client_middleware_start_call + CIntStringVariant variant + + if tls_root_certs: + c_options.tls_root_certs = tobytes(tls_root_certs) + if cert_chain: + c_options.cert_chain = tobytes(cert_chain) + if private_key: + c_options.private_key = tobytes(private_key) + if override_hostname: + c_options.override_hostname = tobytes(override_hostname) + if disable_server_verification is not None: + c_options.disable_server_verification = disable_server_verification + if middleware: + for factory in middleware: + c_options.middleware.push_back( + + make_shared[CPyClientMiddlewareFactory]( + factory, start_call)) + if write_size_limit_bytes is not None: + c_options.write_size_limit_bytes = write_size_limit_bytes + else: + c_options.write_size_limit_bytes = 0 + if generic_options: + for key, value in generic_options: + if isinstance(value, (str, bytes)): + variant = CIntStringVariant( tobytes(value)) + else: + variant = CIntStringVariant( value) + c_options.generic_options.push_back( + pair[c_string, CIntStringVariant](tobytes(key), variant)) + + with nogil: + check_flight_status(CFlightClient.Connect(c_location, c_options + ).Value(&self.client)) + + @property + def supports_async(self): + return self.client.get().supports_async() + + def as_async(self) -> None: + check_status(self.client.get().CheckAsyncSupport()) + return AsyncioFlightClient(self) + + def wait_for_available(self, timeout=5): + """Block until the server can be contacted. + + Parameters + ---------- + timeout : int, default 5 + The maximum seconds to wait. + """ + deadline = time.time() + timeout + while True: + try: + list(self.list_flights()) + except FlightUnavailableError: + if time.time() < deadline: + time.sleep(0.025) + continue + else: + raise + except NotImplementedError: + # allow if list_flights is not implemented, because + # the server can be contacted nonetheless + break + else: + break + + @classmethod + def connect(cls, location, tls_root_certs=None, cert_chain=None, + private_key=None, override_hostname=None, + disable_server_verification=None): + """Connect to a Flight server. + + .. deprecated:: 0.15.0 + Use the ``FlightClient`` constructor or ``pyarrow.flight.connect`` function instead. 
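+
+        Examples
+        --------
+        The preferred replacement for this deprecated method; the URI is
+        a placeholder::
+
+            import pyarrow.flight as flight
+
+            client = flight.connect("grpc://localhost:8815")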
+ """ + warnings.warn("The 'FlightClient.connect' method is deprecated, use " + "FlightClient constructor or pyarrow.flight.connect " + "function instead") + return FlightClient( + location, tls_root_certs=tls_root_certs, + cert_chain=cert_chain, private_key=private_key, + override_hostname=override_hostname, + disable_server_verification=disable_server_verification + ) + + def authenticate(self, auth_handler, options: FlightCallOptions = None): + """Authenticate to the server. + + Parameters + ---------- + auth_handler : ClientAuthHandler + The authentication mechanism to use. + options : FlightCallOptions + Options for this call. + """ + cdef: + unique_ptr[CClientAuthHandler] handler + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + + if not isinstance(auth_handler, ClientAuthHandler): + raise TypeError( + "FlightClient.authenticate takes a ClientAuthHandler, " + "not '{}'".format(type(auth_handler))) + handler.reset(( auth_handler).to_handler()) + with nogil: + check_flight_status( + self.client.get().Authenticate(deref(c_options), + move(handler))) + + def authenticate_basic_token(self, username, password, + options: FlightCallOptions = None): + """Authenticate to the server with HTTP basic authentication. + + Parameters + ---------- + username : string + Username to authenticate with + password : string + Password to authenticate with + options : FlightCallOptions + Options for this call + + Returns + ------- + tuple : Tuple[str, str] + A tuple representing the FlightCallOptions authorization + header entry of a bearer token. + """ + cdef: + CResult[pair[c_string, c_string]] result + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + c_string user = tobytes(username) + c_string pw = tobytes(password) + + with nogil: + result = self.client.get().AuthenticateBasicToken(deref(c_options), + user, pw) + check_flight_status(result.status()) + + return GetResultValue(result) + + def list_actions(self, options: FlightCallOptions = None): + """List the actions available on a service.""" + cdef: + vector[CActionType] results + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + + with SignalStopHandler() as stop_handler: + c_options.stop_token = \ + ( stop_handler.stop_token).stop_token + with nogil: + check_flight_status( + self.client.get().ListActions(deref(c_options)).Value(&results)) + + result = [] + for action_type in results: + py_action = ActionType(frombytes(action_type.type), + frombytes(action_type.description)) + result.append(py_action) + + return result + + def do_action(self, action, options: FlightCallOptions = None): + """ + Execute an action on a service. 
+ + Parameters + ---------- + action : str, tuple, or Action + Can be action type name (no body), type and body, or any Action + object + options : FlightCallOptions + RPC options + + Returns + ------- + results : iterator of Result values + """ + cdef: + unique_ptr[CResultStream] results + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + + if isinstance(action, (str, bytes)): + action = Action(action, b'') + elif isinstance(action, tuple): + action = Action(*action) + elif not isinstance(action, Action): + raise TypeError("Action must be Action instance, string, or tuple") + + cdef CAction c_action = Action.unwrap( action) + with nogil: + check_flight_status( + self.client.get().DoAction( + deref(c_options), c_action).Value(&results)) + + def _do_action_response(): + cdef: + Result result + while True: + result = Result.__new__(Result) + with nogil: + check_flight_status(results.get().Next().Value(&result.result)) + if result.result == NULL: + break + yield result + return _do_action_response() + + def list_flights(self, criteria: bytes = None, + options: FlightCallOptions = None): + """List the flights available on a service.""" + cdef: + unique_ptr[CFlightListing] listing + FlightInfo result + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + CCriteria c_criteria + + if criteria: + c_criteria.expression = tobytes(criteria) + + with SignalStopHandler() as stop_handler: + c_options.stop_token = \ + ( stop_handler.stop_token).stop_token + with nogil: + check_flight_status( + self.client.get().ListFlights(deref(c_options), + c_criteria).Value(&listing)) + + while True: + result = FlightInfo.__new__(FlightInfo) + with nogil: + check_flight_status(listing.get().Next().Value(&result.info)) + if result.info == NULL: + break + yield result + + def get_flight_info(self, descriptor: FlightDescriptor, + options: FlightCallOptions = None): + """Request information about an available flight.""" + cdef: + FlightInfo result = FlightInfo.__new__(FlightInfo) + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + CFlightDescriptor c_descriptor = \ + FlightDescriptor.unwrap(descriptor) + + with nogil: + check_flight_status(self.client.get().GetFlightInfo( + deref(c_options), c_descriptor).Value(&result.info)) + + return result + + def get_schema(self, descriptor: FlightDescriptor, + options: FlightCallOptions = None): + """Request schema for an available flight.""" + cdef: + SchemaResult result = SchemaResult.__new__(SchemaResult) + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + CFlightDescriptor c_descriptor = \ + FlightDescriptor.unwrap(descriptor) + with nogil: + check_status( + self.client.get() + .GetSchema(deref(c_options), c_descriptor).Value(&result.result) + ) + + return result + + def do_get(self, ticket: Ticket, options: FlightCallOptions = None): + """Request the data for a flight. + + Returns + ------- + reader : FlightStreamReader + """ + cdef: + unique_ptr[CFlightStreamReader] reader + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + + with nogil: + check_flight_status( + self.client.get().DoGet( + deref(c_options), ticket.c_ticket).Value(&reader)) + result = FlightStreamReader() + result.reader.reset(reader.release()) + return result + + def do_put(self, descriptor: FlightDescriptor, Schema schema not None, + options: FlightCallOptions = None): + """Upload data to a flight. 
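+
+        Parameters
+        ----------
+        descriptor : FlightDescriptor
+            A descriptor for the flight.
+        schema : Schema
+            The schema of the data to upload.
+        options : FlightCallOptions
+            RPC options.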
+ + Returns + ------- + writer : FlightStreamWriter + reader : FlightMetadataReader + """ + cdef: + shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema) + CDoPutResult c_do_put_result + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + CFlightDescriptor c_descriptor = \ + FlightDescriptor.unwrap(descriptor) + + with nogil: + check_flight_status(self.client.get().DoPut( + deref(c_options), + c_descriptor, + c_schema).Value(&c_do_put_result)) + py_writer = FlightStreamWriter() + py_writer.writer.reset(c_do_put_result.writer.release()) + py_reader = FlightMetadataReader() + py_reader.reader.reset(c_do_put_result.reader.release()) + return py_writer, py_reader + + def do_exchange(self, descriptor: FlightDescriptor, + options: FlightCallOptions = None): + """Start a bidirectional data exchange with a server. + + Parameters + ---------- + descriptor : FlightDescriptor + A descriptor for the flight. + options : FlightCallOptions + RPC options. + + Returns + ------- + writer : FlightStreamWriter + reader : FlightStreamReader + """ + cdef: + CDoExchangeResult c_do_exchange_result + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + CFlightDescriptor c_descriptor = \ + FlightDescriptor.unwrap(descriptor) + + with nogil: + check_flight_status(self.client.get().DoExchange( + deref(c_options), + c_descriptor).Value(&c_do_exchange_result)) + py_writer = FlightStreamWriter() + py_writer.writer.reset(c_do_exchange_result.writer.release()) + py_reader = FlightStreamReader() + py_reader.reader.reset(c_do_exchange_result.reader.release()) + return py_writer, py_reader + + def close(self): + """Close the client and disconnect.""" + client = self.client.get() + if client != NULL: + check_flight_status(client.Close()) + + def __del__(self): + # Not ideal, but close() wasn't originally present so + # applications may not be calling it + self.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + +cdef class FlightDataStream(_Weakrefable): + """ + Abstract base class for Flight data streams. + + See Also + -------- + RecordBatchStream + GeneratorStream + """ + + cdef CFlightDataStream* to_stream(self) except *: + """Create the C++ data stream for the backing Python object. + + We don't expose the C++ object to Python, so we can manage its + lifetime from the Cython/C++ side. + """ + raise NotImplementedError + + +cdef class RecordBatchStream(FlightDataStream): + """A Flight data stream backed by RecordBatches. + + The remainder of this DoGet request will be handled in C++, + without having to acquire the GIL. + + """ + cdef: + object data_source + CIpcWriteOptions write_options + + def __init__(self, data_source, options=None): + """Create a RecordBatchStream from a data source. + + Parameters + ---------- + data_source : RecordBatchReader or Table + The data to stream to the client. + options : pyarrow.ipc.IpcWriteOptions, optional + Optional IPC options to control how to write the data. 
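+
+        Examples
+        --------
+        A sketch of a server handler streaming a table back to the
+        client; the table contents here are placeholders::
+
+            import pyarrow as pa
+            import pyarrow.flight as flight
+
+            class ExampleServer(flight.FlightServerBase):
+                def do_get(self, context, ticket):
+                    table = pa.table({"x": [1, 2, 3]})
+                    return flight.RecordBatchStream(table)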
+ """ + if (not isinstance(data_source, RecordBatchReader) and + not isinstance(data_source, lib.Table)): + raise TypeError("Expected RecordBatchReader or Table, " + "but got: {}".format(type(data_source))) + self.data_source = data_source + self.write_options = _get_options(options).c_options + + cdef CFlightDataStream* to_stream(self) except *: + cdef: + shared_ptr[CRecordBatchReader] reader + if isinstance(self.data_source, RecordBatchReader): + reader = ( self.data_source).reader + elif isinstance(self.data_source, lib.Table): + table = ( self.data_source).table + reader.reset(new TableBatchReader(deref(table))) + else: + raise RuntimeError("Can't construct RecordBatchStream " + "from type {}".format(type(self.data_source))) + return new CRecordBatchStream(reader, self.write_options) + + +cdef class GeneratorStream(FlightDataStream): + """A Flight data stream backed by a Python generator.""" + cdef: + shared_ptr[CSchema] schema + object generator + # A substream currently being consumed by the client, if + # present. Produced by the generator. + unique_ptr[CFlightDataStream] current_stream + CIpcWriteOptions c_options + + def __init__(self, schema, generator, options=None): + """Create a GeneratorStream from a Python generator. + + Parameters + ---------- + schema : Schema + The schema for the data to be returned. + + generator : iterator or iterable + The generator should yield other FlightDataStream objects, + Tables, RecordBatches, or RecordBatchReaders. + + options : pyarrow.ipc.IpcWriteOptions, optional + """ + self.schema = pyarrow_unwrap_schema(schema) + self.generator = iter(generator) + self.c_options = _get_options(options).c_options + + cdef CFlightDataStream* to_stream(self) except *: + cdef: + function[cb_data_stream_next] callback = &_data_stream_next + return new CPyGeneratorFlightDataStream(self, self.schema, callback, + self.c_options) + + +cdef class ServerCallContext(_Weakrefable): + """Per-call state/context.""" + cdef: + const CServerCallContext* context + + def peer_identity(self): + """Get the identity of the authenticated peer. + + May be the empty string. + """ + return tobytes(self.context.peer_identity()) + + def peer(self): + """Get the address of the peer.""" + # Set safe=True as gRPC on Windows sometimes gives garbage bytes + return frombytes(self.context.peer(), safe=True) + + def is_cancelled(self): + """Check if the current RPC call has been canceled by the client.""" + return self.context.is_cancelled() + + def add_header(self, key, value): + """Add a response header.""" + self.context.AddHeader(tobytes(key), tobytes(value)) + + def add_trailer(self, key, value): + """Add a response trailer.""" + self.context.AddTrailer(tobytes(key), tobytes(value)) + + def get_middleware(self, key): + """ + Get a middleware instance by key. + + Returns None if the middleware was not found. 
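+
+        Examples
+        --------
+        A sketch of looking up middleware inside an RPC handler; the key
+        "auth" stands for whatever key the middleware factory was
+        registered under::
+
+            def do_get(self, context, ticket):
+                middleware = context.get_middleware("auth")
+                if middleware is not None:
+                    ...  # e.g. inspect state recorded by the middleware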
+ """ + cdef: + CServerMiddleware* c_middleware = \ + self.context.GetMiddleware(CPyServerMiddlewareName) + CPyServerMiddleware* middleware + vector[CTracingServerMiddlewareTraceKey] c_trace_context + if c_middleware == NULL: + c_middleware = self.context.GetMiddleware(tobytes(key)) + + if c_middleware == NULL: + return None + elif c_middleware.name() == CPyServerMiddlewareName: + middleware = c_middleware + py_middleware = <_ServerMiddlewareWrapper> middleware.py_object() + return py_middleware.middleware.get(key) + elif c_middleware.name() == CTracingServerMiddlewareName: + c_trace_context = ( c_middleware + ).GetTraceContext() + trace_context = {pair.key: pair.value for pair in c_trace_context} + return TracingServerMiddleware(trace_context) + return None + + @staticmethod + cdef ServerCallContext wrap(const CServerCallContext& context): + cdef ServerCallContext result = \ + ServerCallContext.__new__(ServerCallContext) + result.context = &context + return result + + +cdef class ServerAuthReader(_Weakrefable): + """A reader for messages from the client during an auth handshake.""" + cdef: + CServerAuthReader* reader + + def read(self): + cdef c_string token + if not self.reader: + raise ValueError("Cannot use ServerAuthReader outside " + "ServerAuthHandler.authenticate") + with nogil: + check_flight_status(self.reader.Read(&token)) + return token + + cdef void poison(self): + """Prevent further usage of this object. + + This object is constructed by taking a pointer to a reference, + so we want to make sure Python users do not access this after + the reference goes away. + """ + self.reader = NULL + + @staticmethod + cdef ServerAuthReader wrap(CServerAuthReader* reader): + cdef ServerAuthReader result = \ + ServerAuthReader.__new__(ServerAuthReader) + result.reader = reader + return result + + +cdef class ServerAuthSender(_Weakrefable): + """A writer for messages to the client during an auth handshake.""" + cdef: + CServerAuthSender* sender + + def write(self, message): + cdef c_string c_message = tobytes(message) + if not self.sender: + raise ValueError("Cannot use ServerAuthSender outside " + "ServerAuthHandler.authenticate") + with nogil: + check_flight_status(self.sender.Write(c_message)) + + cdef void poison(self): + """Prevent further usage of this object. + + This object is constructed by taking a pointer to a reference, + so we want to make sure Python users do not access this after + the reference goes away. + """ + self.sender = NULL + + @staticmethod + cdef ServerAuthSender wrap(CServerAuthSender* sender): + cdef ServerAuthSender result = \ + ServerAuthSender.__new__(ServerAuthSender) + result.sender = sender + return result + + +cdef class ClientAuthReader(_Weakrefable): + """A reader for messages from the server during an auth handshake.""" + cdef: + CClientAuthReader* reader + + def read(self): + cdef c_string token + if not self.reader: + raise ValueError("Cannot use ClientAuthReader outside " + "ClientAuthHandler.authenticate") + with nogil: + check_flight_status(self.reader.Read(&token)) + return token + + cdef void poison(self): + """Prevent further usage of this object. + + This object is constructed by taking a pointer to a reference, + so we want to make sure Python users do not access this after + the reference goes away. 
+ """ + self.reader = NULL + + @staticmethod + cdef ClientAuthReader wrap(CClientAuthReader* reader): + cdef ClientAuthReader result = \ + ClientAuthReader.__new__(ClientAuthReader) + result.reader = reader + return result + + +cdef class ClientAuthSender(_Weakrefable): + """A writer for messages to the server during an auth handshake.""" + cdef: + CClientAuthSender* sender + + def write(self, message): + cdef c_string c_message = tobytes(message) + if not self.sender: + raise ValueError("Cannot use ClientAuthSender outside " + "ClientAuthHandler.authenticate") + with nogil: + check_flight_status(self.sender.Write(c_message)) + + cdef void poison(self): + """Prevent further usage of this object. + + This object is constructed by taking a pointer to a reference, + so we want to make sure Python users do not access this after + the reference goes away. + """ + self.sender = NULL + + @staticmethod + cdef ClientAuthSender wrap(CClientAuthSender* sender): + cdef ClientAuthSender result = \ + ClientAuthSender.__new__(ClientAuthSender) + result.sender = sender + return result + + +cdef CStatus _data_stream_next(void* self, CFlightPayload* payload) except *: + """Callback for implementing FlightDataStream in Python.""" + cdef: + unique_ptr[CFlightDataStream] data_stream + + py_stream = self + if not isinstance(py_stream, GeneratorStream): + raise RuntimeError("self object in callback is not GeneratorStream") + stream = py_stream + + # The generator is allowed to yield a reader or table which we + # yield from; if that sub-generator is empty, we need to reset and + # try again. However, limit the number of attempts so that we + # don't just spin forever. + max_attempts = 128 + for _ in range(max_attempts): + if stream.current_stream != nullptr: + with nogil: + check_flight_status( + stream.current_stream.get().Next().Value(payload)) + # If the stream ended, see if there's another stream from the + # generator + if payload.ipc_message.metadata != nullptr: + return CStatus_OK() + stream.current_stream.reset(nullptr) + + try: + result = next(stream.generator) + except StopIteration: + payload.ipc_message.metadata.reset( nullptr) + return CStatus_OK() + except FlightError as flight_error: + return ( flight_error).to_status() + + if isinstance(result, (list, tuple)): + result, metadata = result + else: + result, metadata = result, None + + if isinstance(result, (Table, RecordBatchReader)): + if metadata: + raise ValueError("Can only return metadata alongside a " + "RecordBatch.") + result = RecordBatchStream(result) + + stream_schema = pyarrow_wrap_schema(stream.schema) + if isinstance(result, FlightDataStream): + if metadata: + raise ValueError("Can only return metadata alongside a " + "RecordBatch.") + data_stream = unique_ptr[CFlightDataStream]( + ( result).to_stream()) + substream_schema = pyarrow_wrap_schema(data_stream.get().schema()) + if substream_schema != stream_schema: + raise ValueError("Got a FlightDataStream whose schema " + "does not match the declared schema of this " + "GeneratorStream. " + "Got: {}\nExpected: {}".format( + substream_schema, stream_schema)) + stream.current_stream.reset( + new CPyFlightDataStream(result, move(data_stream))) + # Loop around and try again + continue + elif isinstance(result, RecordBatch): + batch = result + if batch.schema != stream_schema: + raise ValueError("Got a RecordBatch whose schema does not " + "match the declared schema of this " + "GeneratorStream. 
" + "Got: {}\nExpected: {}".format(batch.schema, + stream_schema)) + check_flight_status(GetRecordBatchPayload( + deref(batch.batch), + stream.c_options, + &payload.ipc_message)) + if metadata: + payload.app_metadata = pyarrow_unwrap_buffer( + as_buffer(metadata)) + else: + raise TypeError("GeneratorStream must be initialized with " + "an iterator of FlightDataStream, Table, " + "RecordBatch, or RecordBatchStreamReader objects, " + "not {}.".format(type(result))) + # Don't loop around + return CStatus_OK() + # Ran out of attempts (the RPC handler kept yielding empty tables/readers) + raise RuntimeError("While getting next payload, ran out of attempts to " + "get something to send " + "(application server implementation error)") + + +cdef CStatus _list_flights(void* self, const CServerCallContext& context, + const CCriteria* c_criteria, + unique_ptr[CFlightListing]* listing) except *: + """Callback for implementing ListFlights in Python.""" + cdef: + vector[CFlightInfo] flights + + try: + result = ( self).list_flights(ServerCallContext.wrap(context), + c_criteria.expression) + for info in result: + if not isinstance(info, FlightInfo): + raise TypeError("FlightServerBase.list_flights must return " + "FlightInfo instances, but got {}".format( + type(info))) + flights.push_back(deref(( info).info.get())) + listing.reset(new CSimpleFlightListing(flights)) + except FlightError as flight_error: + return ( flight_error).to_status() + return CStatus_OK() + + +cdef CStatus _get_flight_info(void* self, const CServerCallContext& context, + CFlightDescriptor c_descriptor, + unique_ptr[CFlightInfo]* info) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + FlightDescriptor py_descriptor = \ + FlightDescriptor.__new__(FlightDescriptor) + py_descriptor.descriptor = c_descriptor + try: + result = ( self).get_flight_info( + ServerCallContext.wrap(context), + py_descriptor) + except FlightError as flight_error: + return ( flight_error).to_status() + if not isinstance(result, FlightInfo): + raise TypeError("FlightServerBase.get_flight_info must return " + "a FlightInfo instance, but got {}".format( + type(result))) + info.reset(new CFlightInfo(deref(( result).info.get()))) + return CStatus_OK() + +cdef CStatus _get_schema(void* self, const CServerCallContext& context, + CFlightDescriptor c_descriptor, + unique_ptr[CSchemaResult]* info) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + FlightDescriptor py_descriptor = \ + FlightDescriptor.__new__(FlightDescriptor) + py_descriptor.descriptor = c_descriptor + result = ( self).get_schema(ServerCallContext.wrap(context), + py_descriptor) + if not isinstance(result, SchemaResult): + raise TypeError("FlightServerBase.get_schema_info must return " + "a SchemaResult instance, but got {}".format( + type(result))) + info.reset(new CSchemaResult(deref(( result).result.get()))) + return CStatus_OK() + +cdef CStatus _do_put(void* self, const CServerCallContext& context, + unique_ptr[CFlightMessageReader] reader, + unique_ptr[CFlightMetadataWriter] writer) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + MetadataRecordBatchReader py_reader = MetadataRecordBatchReader() + FlightMetadataWriter py_writer = FlightMetadataWriter() + FlightDescriptor descriptor = \ + FlightDescriptor.__new__(FlightDescriptor) + + descriptor.descriptor = reader.get().descriptor() + py_reader.reader.reset(reader.release()) + py_writer.writer.reset(writer.release()) + try: + ( 
self).do_put(ServerCallContext.wrap(context), descriptor, + py_reader, py_writer) + return CStatus_OK() + except FlightError as flight_error: + return ( flight_error).to_status() + + +cdef CStatus _do_get(void* self, const CServerCallContext& context, + CTicket ticket, + unique_ptr[CFlightDataStream]* stream) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + unique_ptr[CFlightDataStream] data_stream + + py_ticket = Ticket(ticket.ticket) + try: + result = ( self).do_get(ServerCallContext.wrap(context), + py_ticket) + except FlightError as flight_error: + return ( flight_error).to_status() + if not isinstance(result, FlightDataStream): + raise TypeError("FlightServerBase.do_get must return " + "a FlightDataStream") + data_stream = unique_ptr[CFlightDataStream]( + ( result).to_stream()) + stream[0] = unique_ptr[CFlightDataStream]( + new CPyFlightDataStream(result, move(data_stream))) + return CStatus_OK() + + +cdef CStatus _do_exchange(void* self, const CServerCallContext& context, + unique_ptr[CFlightMessageReader] reader, + unique_ptr[CFlightMessageWriter] writer) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + MetadataRecordBatchReader py_reader = MetadataRecordBatchReader() + MetadataRecordBatchWriter py_writer = MetadataRecordBatchWriter() + FlightDescriptor descriptor = \ + FlightDescriptor.__new__(FlightDescriptor) + + descriptor.descriptor = reader.get().descriptor() + py_reader.reader.reset(reader.release()) + py_writer.writer.reset(writer.release()) + try: + ( self).do_exchange(ServerCallContext.wrap(context), + descriptor, py_reader, py_writer) + return CStatus_OK() + except FlightError as flight_error: + return ( flight_error).to_status() + + +cdef CStatus _do_action_result_next( + void* self, + unique_ptr[CFlightResult]* result +) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + CFlightResult* c_result + + try: + action_result = next( self) + if not isinstance(action_result, Result): + action_result = Result(action_result) + c_result = ( action_result).result.get() + result.reset(new CFlightResult(deref(c_result))) + except StopIteration: + result.reset(nullptr) + except FlightError as flight_error: + return ( flight_error).to_status() + return CStatus_OK() + + +cdef CStatus _do_action(void* self, const CServerCallContext& context, + const CAction& action, + unique_ptr[CResultStream]* result) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + function[cb_result_next] ptr = &_do_action_result_next + py_action = Action(action.type, pyarrow_wrap_buffer(action.body)) + try: + responses = ( self).do_action(ServerCallContext.wrap(context), + py_action) + except FlightError as flight_error: + return ( flight_error).to_status() + # Let the application return an iterator or anything convertible + # into one + if responses is None: + # Server didn't return anything + responses = [] + result.reset(new CPyFlightResultStream(iter(responses), ptr)) + return CStatus_OK() + + +cdef CStatus _list_actions(void* self, const CServerCallContext& context, + vector[CActionType]* actions) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + CActionType action_type + # Method should return a list of ActionTypes or similar tuple + try: + result = ( self).list_actions(ServerCallContext.wrap(context)) + for action in result: + if not isinstance(action, tuple): + raise TypeError( + "Results of list_actions must be ActionType or tuple") + action_type.type = 
tobytes(action[0]) + action_type.description = tobytes(action[1]) + actions.push_back(action_type) + except FlightError as flight_error: + return ( flight_error).to_status() + return CStatus_OK() + + +cdef CStatus _server_authenticate(void* self, CServerAuthSender* outgoing, + CServerAuthReader* incoming) except *: + """Callback for implementing authentication in Python.""" + sender = ServerAuthSender.wrap(outgoing) + reader = ServerAuthReader.wrap(incoming) + try: + ( self).authenticate(sender, reader) + except FlightError as flight_error: + return ( flight_error).to_status() + finally: + sender.poison() + reader.poison() + return CStatus_OK() + +cdef CStatus _is_valid(void* self, const c_string& token, + c_string* peer_identity) except *: + """Callback for implementing authentication in Python.""" + cdef c_string c_result + try: + c_result = tobytes(( self).is_valid(token)) + peer_identity[0] = c_result + except FlightError as flight_error: + return ( flight_error).to_status() + return CStatus_OK() + + +cdef CStatus _client_authenticate(void* self, CClientAuthSender* outgoing, + CClientAuthReader* incoming) except *: + """Callback for implementing authentication in Python.""" + sender = ClientAuthSender.wrap(outgoing) + reader = ClientAuthReader.wrap(incoming) + try: + ( self).authenticate(sender, reader) + except FlightError as flight_error: + return ( flight_error).to_status() + finally: + sender.poison() + reader.poison() + return CStatus_OK() + + +cdef CStatus _get_token(void* self, c_string* token) except *: + """Callback for implementing authentication in Python.""" + cdef c_string c_result + try: + c_result = tobytes(( self).get_token()) + token[0] = c_result + except FlightError as flight_error: + return ( flight_error).to_status() + return CStatus_OK() + + +cdef CStatus _middleware_sending_headers( + void* self, CAddCallHeaders* add_headers) except *: + """Callback for implementing middleware.""" + try: + headers = ( self).sending_headers() + except FlightError as flight_error: + return ( flight_error).to_status() + + if headers: + for header, values in headers.items(): + if isinstance(values, (str, bytes)): + values = (values,) + # Headers in gRPC (and HTTP/1, HTTP/2) are required to be + # valid, lowercase ASCII. + header = header.lower() + if isinstance(header, str): + header = header.encode("ascii") + for value in values: + if isinstance(value, str): + value = value.encode("ascii") + # Allow bytes values to pass through. 
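+                    # (For gRPC, binary values are only allowed on header
+                    # names ending in "-bin"; other values must be valid
+                    # ASCII text.)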
+ add_headers.AddHeader(header, value) + + return CStatus_OK() + + +cdef CStatus _middleware_call_completed( + void* self, + const CStatus& call_status) except *: + """Callback for implementing middleware.""" + try: + try: + check_flight_status(call_status) + except Exception as e: + ( self).call_completed(e) + else: + ( self).call_completed(None) + except FlightError as flight_error: + return ( flight_error).to_status() + return CStatus_OK() + + +cdef CStatus _middleware_received_headers( + void* self, + const CCallHeaders& c_headers) except *: + """Callback for implementing middleware.""" + try: + headers = convert_headers(c_headers) + ( self).received_headers(headers) + except FlightError as flight_error: + return ( flight_error).to_status() + return CStatus_OK() + + +cdef dict convert_headers(const CCallHeaders& c_headers): + cdef: + CCallHeaders.const_iterator header_iter = c_headers.cbegin() + headers = {} + while header_iter != c_headers.cend(): + header = c_string(deref(header_iter).first).decode("ascii") + value = c_string(deref(header_iter).second) + if not header.endswith("-bin"): + # Text header values in gRPC (and HTTP/1, HTTP/2) are + # required to be valid ASCII. Binary header values are + # exposed as bytes. + value = value.decode("ascii") + headers.setdefault(header, []).append(value) + postincrement(header_iter) + return headers + + +cdef CStatus _server_middleware_start_call( + void* self, + const CCallInfo& c_info, + const CCallHeaders& c_headers, + shared_ptr[CServerMiddleware]* c_instance) except *: + """Callback for implementing server middleware.""" + instance = None + try: + call_info = wrap_call_info(c_info) + headers = convert_headers(c_headers) + instance = ( self).start_call(call_info, headers) + except FlightError as flight_error: + return ( flight_error).to_status() + + if instance: + ServerMiddleware.wrap(instance, c_instance) + + return CStatus_OK() + + +cdef CStatus _client_middleware_start_call( + void* self, + const CCallInfo& c_info, + unique_ptr[CClientMiddleware]* c_instance) except *: + """Callback for implementing client middleware.""" + instance = None + try: + call_info = wrap_call_info(c_info) + instance = ( self).start_call(call_info) + except FlightError as flight_error: + return ( flight_error).to_status() + + if instance: + ClientMiddleware.wrap(instance, c_instance) + + return CStatus_OK() + + +cdef class ServerAuthHandler(_Weakrefable): + """Authentication middleware for a server. + + To implement an authentication mechanism, subclass this class and + override its methods. + + """ + + def authenticate(self, outgoing, incoming): + """Conduct the handshake with the client. + + May raise an error if the client cannot authenticate. + + Parameters + ---------- + outgoing : ServerAuthSender + A channel to send messages to the client. + incoming : ServerAuthReader + A channel to read messages from the client. + """ + raise NotImplementedError + + def is_valid(self, token): + """Validate a client token, returning their identity. + + May return an empty string (if the auth mechanism does not + name the peer) or raise an exception (if the token is + invalid). + + Parameters + ---------- + token : bytes + The authentication token from the client. 
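+
+        Returns
+        -------
+        identity : str or bytes
+            The identity of the peer, if the token is valid.
+
+        Examples
+        --------
+        A toy handler that accepts every client without naming it; a
+        real handler would validate the token and return a real
+        identity::
+
+            import pyarrow.flight as flight
+
+            class NoOpAuthHandler(flight.ServerAuthHandler):
+                def authenticate(self, outgoing, incoming):
+                    pass
+
+                def is_valid(self, token):
+                    return ""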
+ + """ + raise NotImplementedError + + cdef PyServerAuthHandler* to_handler(self): + cdef PyServerAuthHandlerVtable vtable + vtable.authenticate = _server_authenticate + vtable.is_valid = _is_valid + return new PyServerAuthHandler(self, vtable) + + +cdef class ClientAuthHandler(_Weakrefable): + """Authentication plugin for a client.""" + + def authenticate(self, outgoing, incoming): + """Conduct the handshake with the server. + + Parameters + ---------- + outgoing : ClientAuthSender + A channel to send messages to the server. + incoming : ClientAuthReader + A channel to read messages from the server. + """ + raise NotImplementedError + + def get_token(self): + """Get the auth token for a call.""" + raise NotImplementedError + + cdef PyClientAuthHandler* to_handler(self): + cdef PyClientAuthHandlerVtable vtable + vtable.authenticate = _client_authenticate + vtable.get_token = _get_token + return new PyClientAuthHandler(self, vtable) + + +_CallInfo = collections.namedtuple("_CallInfo", ["method"]) + + +class CallInfo(_CallInfo): + """Information about a particular RPC for Flight middleware.""" + + +cdef wrap_call_info(const CCallInfo& c_info): + method = wrap_flight_method(c_info.method) + return CallInfo(method=method) + + +cdef class ClientMiddlewareFactory(_Weakrefable): + """A factory for new middleware instances. + + All middleware methods will be called from the same thread as the + RPC method implementation. That is, thread-locals set in the + client are accessible from the middleware itself. + + """ + + def start_call(self, info): + """Called at the start of an RPC. + + This must be thread-safe and must not raise exceptions. + + Parameters + ---------- + info : CallInfo + Information about the call. + + Returns + ------- + instance : ClientMiddleware + An instance of ClientMiddleware (the instance to use for + the call), or None if this call is not intercepted. + + """ + + +cdef class ClientMiddleware(_Weakrefable): + """Client-side middleware for a call, instantiated per RPC. + + Methods here should be fast and must be infallible: they should + not raise exceptions or stall indefinitely. + + """ + + def sending_headers(self): + """A callback before headers are sent. + + Returns + ------- + headers : dict + A dictionary of header values to add to the request, or + None if no headers are to be added. The dictionary should + have string keys and string or list-of-string values. + + Bytes values are allowed, but the underlying transport may + not support them or may restrict them. For gRPC, binary + values are only allowed on headers ending in "-bin". + + Header names must be lowercase ASCII. + + """ + + def received_headers(self, headers): + """A callback when headers are received. + + The default implementation does nothing. + + Parameters + ---------- + headers : dict + A dictionary of headers from the server. Keys are strings + and values are lists of strings (for text headers) or + bytes (for binary headers). + + """ + + def call_completed(self, exception): + """A callback when the call finishes. + + The default implementation does nothing. + + Parameters + ---------- + exception : ArrowException + If the call errored, this is the equivalent + exception. Will be None if the call succeeded. 
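+
+        Examples
+        --------
+        A sketch of a middleware pair that attaches a static header to
+        every outgoing call; the URI and header value are placeholders::
+
+            import pyarrow.flight as flight
+
+            class StaticAuthMiddleware(flight.ClientMiddleware):
+                def __init__(self, token):
+                    self.token = token
+
+                def sending_headers(self):
+                    # header names must be lowercase ASCII
+                    return {"authorization": "Bearer " + self.token}
+
+            class StaticAuthMiddlewareFactory(flight.ClientMiddlewareFactory):
+                def __init__(self, token):
+                    self.token = token
+
+                def start_call(self, info):
+                    return StaticAuthMiddleware(self.token)
+
+            client = flight.connect(
+                "grpc://localhost:8815",  # placeholder URI
+                middleware=[StaticAuthMiddlewareFactory("example-token")])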
+ + """ + + @staticmethod + cdef void wrap(object py_middleware, + unique_ptr[CClientMiddleware]* c_instance): + cdef PyClientMiddlewareVtable vtable + vtable.sending_headers = _middleware_sending_headers + vtable.received_headers = _middleware_received_headers + vtable.call_completed = _middleware_call_completed + c_instance[0].reset(new CPyClientMiddleware(py_middleware, vtable)) + + +cdef class ServerMiddlewareFactory(_Weakrefable): + """A factory for new middleware instances. + + All middleware methods will be called from the same thread as the + RPC method implementation. That is, thread-locals set in the + middleware are accessible from the method itself. + + """ + + def start_call(self, info, headers): + """Called at the start of an RPC. + + This must be thread-safe. + + Parameters + ---------- + info : CallInfo + Information about the call. + headers : dict + A dictionary of headers from the client. Keys are strings + and values are lists of strings (for text headers) or + bytes (for binary headers). + + Returns + ------- + instance : ServerMiddleware + An instance of ServerMiddleware (the instance to use for + the call), or None if this call is not intercepted. + + Raises + ------ + exception : pyarrow.ArrowException + If an exception is raised, the call will be rejected with + the given error. + + """ + + +cdef class TracingServerMiddlewareFactory(ServerMiddlewareFactory): + """A factory for tracing middleware instances. + + This enables OpenTelemetry support in Arrow (if Arrow was compiled + with OpenTelemetry support enabled). A new span will be started on + each RPC call. The TracingServerMiddleware instance can then be + retrieved within an RPC handler to get the propagated context, + which can be used to start a new span on the Python side. + + Because the Python/C++ OpenTelemetry libraries do not + interoperate, spans on the C++ side are not directly visible to + the Python side and vice versa. + + """ + + +cdef class ServerMiddleware(_Weakrefable): + """Server-side middleware for a call, instantiated per RPC. + + Methods here should be fast and must be infallible: they should + not raise exceptions or stall indefinitely. + + """ + + def sending_headers(self): + """A callback before headers are sent. + + Returns + ------- + headers : dict + A dictionary of header values to add to the response, or + None if no headers are to be added. The dictionary should + have string keys and string or list-of-string values. + + Bytes values are allowed, but the underlying transport may + not support them or may restrict them. For gRPC, binary + values are only allowed on headers ending in "-bin". + + Header names must be lowercase ASCII. + + """ + + def call_completed(self, exception): + """A callback when the call finishes. + + Parameters + ---------- + exception : pyarrow.ArrowException + If the call errored, this is the equivalent + exception. Will be None if the call succeeded. 
+ + """ + + @staticmethod + cdef void wrap(object py_middleware, + shared_ptr[CServerMiddleware]* c_instance): + cdef PyServerMiddlewareVtable vtable + vtable.sending_headers = _middleware_sending_headers + vtable.call_completed = _middleware_call_completed + c_instance[0].reset(new CPyServerMiddleware(py_middleware, vtable)) + + +class TracingServerMiddleware(ServerMiddleware): + __slots__ = ["trace_context"] + + def __init__(self, trace_context): + self.trace_context = trace_context + + +cdef class _ServerMiddlewareFactoryWrapper(ServerMiddlewareFactory): + """Wrapper to bundle server middleware into a single C++ one.""" + + cdef: + dict factories + + def __init__(self, dict factories): + self.factories = factories + + def start_call(self, info, headers): + instances = {} + for key, factory in self.factories.items(): + instance = factory.start_call(info, headers) + if instance: + # TODO: prevent duplicate keys + instances[key] = instance + if instances: + wrapper = _ServerMiddlewareWrapper(instances) + return wrapper + return None + + +cdef class _ServerMiddlewareWrapper(ServerMiddleware): + cdef: + dict middleware + + def __init__(self, dict middleware): + self.middleware = middleware + + def sending_headers(self): + headers = collections.defaultdict(list) + for instance in self.middleware.values(): + more_headers = instance.sending_headers() + if not more_headers: + continue + # Manually merge with existing headers (since headers are + # multi-valued) + for key, values in more_headers.items(): + # ARROW-16606 gRPC aborts given non-lowercase headers + key = key.lower() + if isinstance(values, (bytes, str)): + values = (values,) + headers[key].extend(values) + return headers + + def call_completed(self, exception): + for instance in self.middleware.values(): + instance.call_completed(exception) + + +cdef class _FlightServerFinalizer(_Weakrefable): + """ + A finalizer that shuts down the server on destruction. + + See ARROW-16597. If the server is still active at interpreter + exit, the process may segfault. + """ + + cdef: + shared_ptr[PyFlightServer] server + + def finalize(self): + cdef: + PyFlightServer* server = self.server.get() + CStatus status + if server == NULL: + return + try: + with nogil: + status = server.Shutdown() + if status.ok(): + status = server.Wait() + check_flight_status(status) + finally: + self.server.reset() + + +cdef class FlightServerBase(_Weakrefable): + """A Flight service definition. + + To start the server, create an instance of this class with an + appropriate location. The server will be running as soon as the + instance is created; it is not required to call :meth:`serve`. + + Override methods to define your Flight service. + + Parameters + ---------- + location : str, tuple or Location optional, default None + Location to serve on. Either a gRPC URI like `grpc://localhost:port`, + a tuple of (host, port) pair, or a Location instance. + If None is passed then the server will be started on localhost with a + system provided random port. + auth_handler : ServerAuthHandler optional, default None + An authentication mechanism to use. May be None. + tls_certificates : list optional, default None + A list of (certificate, key) pairs. + verify_client : boolean optional, default False + If True, then enable mutual TLS: require the client to present + a client certificate, and validate the certificate. 
+ root_certificates : bytes optional, default None + If enabling mutual TLS, this specifies the PEM-encoded root + certificate used to validate client certificates. + middleware : dict optional, default None + A dictionary of :class:`ServerMiddlewareFactory` instances. The + string keys can be used to retrieve the middleware instance within + RPC handlers (see :meth:`ServerCallContext.get_middleware`). + + """ + + cdef: + shared_ptr[PyFlightServer] server + object finalizer + + def __init__(self, location=None, auth_handler=None, + tls_certificates=None, verify_client=None, + root_certificates=None, middleware=None): + self.finalizer = None + if isinstance(location, (bytes, str)): + location = Location(location) + elif isinstance(location, (tuple, type(None))): + if location is None: + location = ('localhost', 0) + host, port = location + if tls_certificates: + location = Location.for_grpc_tls(host, port) + else: + location = Location.for_grpc_tcp(host, port) + elif not isinstance(location, Location): + raise TypeError('`location` argument must be a string, tuple or a ' + 'Location instance') + self.init(location, auth_handler, tls_certificates, verify_client, + tobytes(root_certificates or b""), middleware) + + cdef init(self, Location location, ServerAuthHandler auth_handler, + list tls_certificates, c_bool verify_client, + bytes root_certificates, dict middleware): + cdef: + PyFlightServerVtable vtable = PyFlightServerVtable() + PyFlightServer* c_server + unique_ptr[CFlightServerOptions] c_options + CCertKeyPair c_cert + function[cb_server_middleware_start_call] start_call = \ + &_server_middleware_start_call + pair[c_string, shared_ptr[CServerMiddlewareFactory]] c_middleware + + c_options.reset(new CFlightServerOptions(Location.unwrap(location))) + # mTLS configuration + c_options.get().verify_client = verify_client + c_options.get().root_certificates = root_certificates + + if auth_handler: + if not isinstance(auth_handler, ServerAuthHandler): + raise TypeError("auth_handler must be a ServerAuthHandler, " + "not a '{}'".format(type(auth_handler))) + c_options.get().auth_handler.reset( + ( auth_handler).to_handler()) + + if tls_certificates: + for cert, key in tls_certificates: + c_cert.pem_cert = tobytes(cert) + c_cert.pem_key = tobytes(key) + c_options.get().tls_certificates.push_back(c_cert) + + if middleware: + non_tracing_middleware = {} + enable_tracing = None + for key, factory in middleware.items(): + if isinstance(factory, TracingServerMiddlewareFactory): + if enable_tracing is not None: + raise ValueError( + "Can only provide " + "TracingServerMiddlewareFactory once") + if tobytes(key) == CPyServerMiddlewareName: + raise ValueError(f"Middleware key cannot be {key}") + enable_tracing = key + else: + non_tracing_middleware[key] = factory + + if enable_tracing: + c_middleware.first = tobytes(enable_tracing) + c_middleware.second = MakeTracingServerMiddlewareFactory() + c_options.get().middleware.push_back(c_middleware) + + py_middleware = _ServerMiddlewareFactoryWrapper( + non_tracing_middleware) + c_middleware.first = CPyServerMiddlewareName + c_middleware.second.reset(new CPyServerMiddlewareFactory( + py_middleware, + start_call)) + c_options.get().middleware.push_back(c_middleware) + + vtable.list_flights = &_list_flights + vtable.get_flight_info = &_get_flight_info + vtable.get_schema = &_get_schema + vtable.do_put = &_do_put + vtable.do_get = &_do_get + vtable.do_exchange = &_do_exchange + vtable.list_actions = &_list_actions + vtable.do_action = &_do_action + + c_server = 
new PyFlightServer(self, vtable) + self.server.reset(c_server) + with nogil: + check_flight_status(c_server.Init(deref(c_options))) + cdef _FlightServerFinalizer finalizer = _FlightServerFinalizer() + finalizer.server = self.server + self.finalizer = weakref.finalize(self, finalizer.finalize) + + @property + def port(self): + """ + Get the port that this server is listening on. + + Returns a non-positive value if the operation is invalid + (e.g. init() was not called or server is listening on a domain + socket). + """ + return self.server.get().port() + + def list_flights(self, context, criteria): + """List flights available on this service. + + Applications should override this method to implement their + own behavior. The default method raises a NotImplementedError. + + Parameters + ---------- + context : ServerCallContext + Common contextual information. + criteria : bytes + Filter criteria provided by the client. + + Returns + ------- + iterator of FlightInfo + + """ + raise NotImplementedError + + def get_flight_info(self, context, descriptor): + """Get information about a flight. + + Applications should override this method to implement their + own behavior. The default method raises a NotImplementedError. + + Parameters + ---------- + context : ServerCallContext + Common contextual information. + descriptor : FlightDescriptor + The descriptor for the flight provided by the client. + + Returns + ------- + FlightInfo + + """ + raise NotImplementedError + + def get_schema(self, context, descriptor): + """Get the schema of a flight. + + Applications should override this method to implement their + own behavior. The default method raises a NotImplementedError. + + Parameters + ---------- + context : ServerCallContext + Common contextual information. + descriptor : FlightDescriptor + The descriptor for the flight provided by the client. + + Returns + ------- + Schema + + """ + raise NotImplementedError + + def do_put(self, context, descriptor, reader: MetadataRecordBatchReader, + writer: FlightMetadataWriter): + """Write data to a flight. + + Applications should override this method to implement their + own behavior. The default method raises a NotImplementedError. + + Parameters + ---------- + context : ServerCallContext + Common contextual information. + descriptor : FlightDescriptor + The descriptor for the flight provided by the client. + reader : MetadataRecordBatchReader + A reader for data uploaded by the client. + writer : FlightMetadataWriter + A writer to send responses to the client. + + """ + raise NotImplementedError + + def do_get(self, context, ticket): + """Write data to a flight. + + Applications should override this method to implement their + own behavior. The default method raises a NotImplementedError. + + Parameters + ---------- + context : ServerCallContext + Common contextual information. + ticket : Ticket + The ticket for the flight. + + Returns + ------- + FlightDataStream + A stream of data to send back to the client. + + """ + raise NotImplementedError + + def do_exchange(self, context, descriptor, reader, writer): + """Write data to a flight. + + Applications should override this method to implement their + own behavior. The default method raises a NotImplementedError. + + Parameters + ---------- + context : ServerCallContext + Common contextual information. + descriptor : FlightDescriptor + The descriptor for the flight provided by the client. + reader : MetadataRecordBatchReader + A reader for data uploaded by the client. 
+ writer : MetadataRecordBatchWriter + A writer to send responses to the client. + + """ + raise NotImplementedError + + def list_actions(self, context): + """List custom actions available on this server. + + Applications should override this method to implement their + own behavior. The default method raises a NotImplementedError. + + Parameters + ---------- + context : ServerCallContext + Common contextual information. + + Returns + ------- + iterator of ActionType or tuple + + """ + raise NotImplementedError + + def do_action(self, context, action): + """Execute a custom action. + + This method should return an iterator, or it should be a + generator. Applications should override this method to + implement their own behavior. The default method raises a + NotImplementedError. + + Parameters + ---------- + context : ServerCallContext + Common contextual information. + action : Action + The action to execute. + + Returns + ------- + iterator of bytes + + """ + raise NotImplementedError + + def serve(self): + """Block until the server shuts down. + + This method only returns if shutdown() is called or a signal is + received. + """ + if self.server.get() == nullptr: + raise ValueError("run() on uninitialized FlightServerBase") + with nogil: + check_flight_status(self.server.get().ServeWithSignals()) + + def run(self): + """Block until the server shuts down. + + .. deprecated:: 0.15.0 + Use the ``FlightServer.serve`` method instead + """ + warnings.warn("The 'FlightServer.run' method is deprecated, use " + "FlightServer.serve method instead") + self.serve() + + def shutdown(self): + """Shut down the server, blocking until current requests finish. + + Do not call this directly from the implementation of a Flight + method, as then the server will block forever waiting for that + request to finish. Instead, call this method from a background + thread. + + This method should only be called once. + """ + # Must not hold the GIL: shutdown waits for pending RPCs to + # complete. Holding the GIL means Python-implemented Flight + # methods will never get to run, so this will hang + # indefinitely. + if self.server.get() == nullptr: + raise ValueError("shutdown() on uninitialized FlightServerBase") + with nogil: + check_flight_status(self.server.get().Shutdown()) + + def wait(self): + """Block until server is terminated with shutdown.""" + with nogil: + self.server.get().Wait() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.finalizer: + self.finalizer() + + +def connect(location, **kwargs): + """ + Connect to a Flight server. + + Parameters + ---------- + location : str, tuple, or Location + Location to connect to. Either a URI like "grpc://localhost:port", + a tuple of (host, port), or a Location instance. + tls_root_certs : bytes or None + PEM-encoded. + cert_chain: str or None + If provided, enables TLS mutual authentication. + private_key: str or None + If provided, enables TLS mutual authentication. + override_hostname : str or None + Override the hostname checked by TLS. Insecure, use with caution. + middleware : list or None + A list of ClientMiddlewareFactory instances to apply. + write_size_limit_bytes : int or None + A soft limit on the size of a data payload sent to the + server. Enabled if positive. If enabled, writing a record + batch that (when serialized) exceeds this limit will raise an + exception; the client can retry the write with a smaller + batch. 
+ disable_server_verification : boolean or None + Disable verifying the server when using TLS. + Insecure, use with caution. + generic_options : list or None + A list of generic (string, int or string) options to pass to + the underlying transport. + + Returns + ------- + client : FlightClient + """ + return FlightClient(location, **kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_gcsfs.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_gcsfs.pyx new file mode 100644 index 0000000000000000000000000000000000000000..5e69413cea953639e36ba5485cb383b88193748b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_gcsfs.pyx @@ -0,0 +1,212 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from cython cimport binding + +from pyarrow.lib cimport (pyarrow_wrap_metadata, + pyarrow_unwrap_metadata) +from pyarrow.lib import frombytes, tobytes, ensure_metadata +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_fs cimport * +from pyarrow._fs cimport FileSystem, TimePoint_to_ns, PyDateTime_to_TimePoint + +from datetime import datetime, timedelta, timezone + + +cdef class GcsFileSystem(FileSystem): + """ + Google Cloud Storage (GCS) backed FileSystem implementation + + By default uses the process described in https://google.aip.dev/auth/4110 + to resolve credentials. If not running on Google Cloud Platform (GCP), + this generally requires the environment variable + GOOGLE_APPLICATION_CREDENTIALS to point to a JSON file + containing credentials. + + Note: GCS buckets are special and the operations available on them may be + limited or more expensive than expected compared to local file systems. + + Note: When pickling a GcsFileSystem that uses default credentials, resolution + credentials are not stored in the serialized data. Therefore, when unpickling + it is assumed that the necessary credentials are in place for the target + process. + + Parameters + ---------- + anonymous : boolean, default False + Whether to connect anonymously. + If true, will not attempt to look up credentials using standard GCP + configuration methods. + access_token : str, default None + GCP access token. If provided, temporary credentials will be fetched by + assuming this role; also, a `credential_token_expiration` must be + specified as well. + target_service_account : str, default None + An optional service account to try to impersonate when accessing GCS. This + requires the specified credential user or service account to have the necessary + permissions. + credential_token_expiration : datetime, default None + Expiration for credential generated with an access token. 
Must be specified + if `access_token` is specified. + default_bucket_location : str, default 'US' + GCP region to create buckets in. + scheme : str, default 'https' + GCS connection transport scheme. + endpoint_override : str, default None + Override endpoint with a connect string such as "localhost:9000" + default_metadata : mapping or pyarrow.KeyValueMetadata, default None + Default metadata for `open_output_stream`. This will be ignored if + non-empty metadata is passed to `open_output_stream`. + retry_time_limit : timedelta, default None + Set the maximum amount of time the GCS client will attempt to retry + transient errors. Subsecond granularity is ignored. + project_id : str, default None + The GCP project identifier to use for creating buckets. + If not set, the library uses the GOOGLE_CLOUD_PROJECT environment + variable. Most I/O operations do not need a project id, only applications + that create new buckets need a project id. + """ + + cdef: + CGcsFileSystem* gcsfs + + def __init__(self, *, bint anonymous=False, access_token=None, + target_service_account=None, credential_token_expiration=None, + default_bucket_location='US', + scheme=None, + endpoint_override=None, + default_metadata=None, + retry_time_limit=None, + project_id=None): + cdef: + CGcsOptions options + shared_ptr[CGcsFileSystem] wrapped + double time_limit_seconds + + # Intentional use of truthiness because empty strings aren't valid and + # for reconstruction from pickling will give empty strings. + if anonymous and (target_service_account or access_token): + raise ValueError( + 'anonymous option is not compatible with target_service_account and ' + 'access_token' + ) + elif bool(access_token) != bool(credential_token_expiration): + raise ValueError( + 'access_token and credential_token_expiration must be ' + 'specified together' + ) + + elif anonymous: + options = CGcsOptions.Anonymous() + elif access_token: + if not isinstance(credential_token_expiration, datetime): + raise ValueError( + "credential_token_expiration must be a datetime") + options = CGcsOptions.FromAccessToken( + tobytes(access_token), + PyDateTime_to_TimePoint(credential_token_expiration)) + else: + options = CGcsOptions.Defaults() + + # Target service account requires base credentials so + # it is not part of the if/else chain above which only + # handles base credentials. 
+ if target_service_account: + options = CGcsOptions.FromImpersonatedServiceAccount( + options.credentials, tobytes(target_service_account)) + + options.default_bucket_location = tobytes(default_bucket_location) + + if scheme is not None: + options.scheme = tobytes(scheme) + if endpoint_override is not None: + options.endpoint_override = tobytes(endpoint_override) + if default_metadata is not None: + options.default_metadata = pyarrow_unwrap_metadata( + ensure_metadata(default_metadata)) + if retry_time_limit is not None: + time_limit_seconds = retry_time_limit.total_seconds() + options.retry_limit_seconds = time_limit_seconds + if project_id is not None: + options.project_id = tobytes(project_id) + + with nogil: + wrapped = GetResultValue(CGcsFileSystem.Make(options)) + + self.init( wrapped) + + cdef init(self, const shared_ptr[CFileSystem]& wrapped): + FileSystem.init(self, wrapped) + self.gcsfs = wrapped.get() + + def _expiration_datetime_from_options(self): + expiration_ns = TimePoint_to_ns( + self.gcsfs.options().credentials.expiration()) + if expiration_ns == 0: + return None + return datetime.fromtimestamp(expiration_ns / 1.0e9, timezone.utc) + + @staticmethod + @binding(True) # Required for cython < 3 + def _reconstruct(kwargs): + # __reduce__ doesn't allow passing named arguments directly to the + # reconstructor, hence this wrapper. + return GcsFileSystem(**kwargs) + + def __reduce__(self): + cdef CGcsOptions opts = self.gcsfs.options() + service_account = frombytes(opts.credentials.target_service_account()) + expiration_dt = self._expiration_datetime_from_options() + retry_time_limit = None + if opts.retry_limit_seconds.has_value(): + retry_time_limit = timedelta( + seconds=opts.retry_limit_seconds.value()) + project_id = None + if opts.project_id.has_value(): + project_id = frombytes(opts.project_id.value()) + return ( + GcsFileSystem._reconstruct, (dict( + access_token=frombytes(opts.credentials.access_token()), + anonymous=opts.credentials.anonymous(), + credential_token_expiration=expiration_dt, + target_service_account=service_account, + scheme=frombytes(opts.scheme), + endpoint_override=frombytes(opts.endpoint_override), + default_bucket_location=frombytes( + opts.default_bucket_location), + default_metadata=pyarrow_wrap_metadata(opts.default_metadata), + retry_time_limit=retry_time_limit, + project_id=project_id + ),)) + + @property + def default_bucket_location(self): + """ + The GCP location this filesystem will write to. + """ + return frombytes(self.gcsfs.options().default_bucket_location) + + @property + def project_id(self): + """ + The GCP project id this filesystem will use. + """ + if self.gcsfs.options().project_id.has_value(): + return frombytes(self.gcsfs.options().project_id.value()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_generated_version.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_generated_version.py new file mode 100644 index 0000000000000000000000000000000000000000..d3cd8202a74a12f1643c491c1ae37b77dba90a21 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_generated_version.py @@ -0,0 +1,16 @@ +# file generated by setuptools_scm +# don't change, don't track in version control +TYPE_CHECKING = False +if TYPE_CHECKING: + from typing import Tuple, Union + VERSION_TUPLE = Tuple[Union[int, str], ...] 
+else: + VERSION_TUPLE = object + +version: str +__version__: str +__version_tuple__: VERSION_TUPLE +version_tuple: VERSION_TUPLE + +__version__ = version = '19.0.1' +__version_tuple__ = version_tuple = (19, 0, 1) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_parquet.pxd b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_parquet.pxd new file mode 100644 index 0000000000000000000000000000000000000000..d6aebd8284f4a2a0a54d7bcbc9cdccbb03c7ef83 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_parquet.pxd @@ -0,0 +1,680 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ +# cython: language_level = 3 + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport (CChunkedArray, CScalar, CSchema, CStatus, + CTable, CMemoryPool, CBuffer, + CKeyValueMetadata, CRandomAccessFile, + COutputStream, CCacheOptions, + TimeUnit, CRecordBatchReader) +from pyarrow.lib cimport _Weakrefable + + +cdef extern from "parquet/api/schema.h" namespace "parquet::schema" nogil: + cdef cppclass Node: + pass + + cdef cppclass GroupNode(Node): + pass + + cdef cppclass PrimitiveNode(Node): + pass + + cdef cppclass ColumnPath: + c_string ToDotString() + vector[c_string] ToDotVector() + + +cdef extern from "parquet/api/schema.h" namespace "parquet" nogil: + enum ParquetType" parquet::Type::type": + ParquetType_BOOLEAN" parquet::Type::BOOLEAN" + ParquetType_INT32" parquet::Type::INT32" + ParquetType_INT64" parquet::Type::INT64" + ParquetType_INT96" parquet::Type::INT96" + ParquetType_FLOAT" parquet::Type::FLOAT" + ParquetType_DOUBLE" parquet::Type::DOUBLE" + ParquetType_BYTE_ARRAY" parquet::Type::BYTE_ARRAY" + ParquetType_FIXED_LEN_BYTE_ARRAY" parquet::Type::FIXED_LEN_BYTE_ARRAY" + + enum ParquetLogicalTypeId" parquet::LogicalType::Type::type": + ParquetLogicalType_UNDEFINED" parquet::LogicalType::Type::UNDEFINED" + ParquetLogicalType_STRING" parquet::LogicalType::Type::STRING" + ParquetLogicalType_MAP" parquet::LogicalType::Type::MAP" + ParquetLogicalType_LIST" parquet::LogicalType::Type::LIST" + ParquetLogicalType_ENUM" parquet::LogicalType::Type::ENUM" + ParquetLogicalType_DECIMAL" parquet::LogicalType::Type::DECIMAL" + ParquetLogicalType_DATE" parquet::LogicalType::Type::DATE" + ParquetLogicalType_TIME" parquet::LogicalType::Type::TIME" + ParquetLogicalType_TIMESTAMP" parquet::LogicalType::Type::TIMESTAMP" + ParquetLogicalType_INT" parquet::LogicalType::Type::INT" + ParquetLogicalType_FLOAT16" parquet::LogicalType::Type::FLOAT16" + ParquetLogicalType_JSON" parquet::LogicalType::Type::JSON" + ParquetLogicalType_BSON" parquet::LogicalType::Type::BSON" + ParquetLogicalType_UUID" parquet::LogicalType::Type::UUID" + ParquetLogicalType_NONE" 
parquet::LogicalType::Type::NONE" + + enum ParquetTimeUnit" parquet::LogicalType::TimeUnit::unit": + ParquetTimeUnit_UNKNOWN" parquet::LogicalType::TimeUnit::UNKNOWN" + ParquetTimeUnit_MILLIS" parquet::LogicalType::TimeUnit::MILLIS" + ParquetTimeUnit_MICROS" parquet::LogicalType::TimeUnit::MICROS" + ParquetTimeUnit_NANOS" parquet::LogicalType::TimeUnit::NANOS" + + enum ParquetConvertedType" parquet::ConvertedType::type": + ParquetConvertedType_NONE" parquet::ConvertedType::NONE" + ParquetConvertedType_UTF8" parquet::ConvertedType::UTF8" + ParquetConvertedType_MAP" parquet::ConvertedType::MAP" + ParquetConvertedType_MAP_KEY_VALUE \ + " parquet::ConvertedType::MAP_KEY_VALUE" + ParquetConvertedType_LIST" parquet::ConvertedType::LIST" + ParquetConvertedType_ENUM" parquet::ConvertedType::ENUM" + ParquetConvertedType_DECIMAL" parquet::ConvertedType::DECIMAL" + ParquetConvertedType_DATE" parquet::ConvertedType::DATE" + ParquetConvertedType_TIME_MILLIS" parquet::ConvertedType::TIME_MILLIS" + ParquetConvertedType_TIME_MICROS" parquet::ConvertedType::TIME_MICROS" + ParquetConvertedType_TIMESTAMP_MILLIS \ + " parquet::ConvertedType::TIMESTAMP_MILLIS" + ParquetConvertedType_TIMESTAMP_MICROS \ + " parquet::ConvertedType::TIMESTAMP_MICROS" + ParquetConvertedType_UINT_8" parquet::ConvertedType::UINT_8" + ParquetConvertedType_UINT_16" parquet::ConvertedType::UINT_16" + ParquetConvertedType_UINT_32" parquet::ConvertedType::UINT_32" + ParquetConvertedType_UINT_64" parquet::ConvertedType::UINT_64" + ParquetConvertedType_INT_8" parquet::ConvertedType::INT_8" + ParquetConvertedType_INT_16" parquet::ConvertedType::INT_16" + ParquetConvertedType_INT_32" parquet::ConvertedType::INT_32" + ParquetConvertedType_INT_64" parquet::ConvertedType::INT_64" + ParquetConvertedType_JSON" parquet::ConvertedType::JSON" + ParquetConvertedType_BSON" parquet::ConvertedType::BSON" + ParquetConvertedType_INTERVAL" parquet::ConvertedType::INTERVAL" + + enum ParquetRepetition" parquet::Repetition::type": + ParquetRepetition_REQUIRED" parquet::REPETITION::REQUIRED" + ParquetRepetition_OPTIONAL" parquet::REPETITION::OPTIONAL" + ParquetRepetition_REPEATED" parquet::REPETITION::REPEATED" + + enum ParquetEncoding" parquet::Encoding::type": + ParquetEncoding_PLAIN" parquet::Encoding::PLAIN" + ParquetEncoding_PLAIN_DICTIONARY" parquet::Encoding::PLAIN_DICTIONARY" + ParquetEncoding_RLE" parquet::Encoding::RLE" + ParquetEncoding_BIT_PACKED" parquet::Encoding::BIT_PACKED" + ParquetEncoding_DELTA_BINARY_PACKED \ + " parquet::Encoding::DELTA_BINARY_PACKED" + ParquetEncoding_DELTA_LENGTH_BYTE_ARRAY \ + " parquet::Encoding::DELTA_LENGTH_BYTE_ARRAY" + ParquetEncoding_DELTA_BYTE_ARRAY" parquet::Encoding::DELTA_BYTE_ARRAY" + ParquetEncoding_RLE_DICTIONARY" parquet::Encoding::RLE_DICTIONARY" + ParquetEncoding_BYTE_STREAM_SPLIT \ + " parquet::Encoding::BYTE_STREAM_SPLIT" + + enum ParquetCompression" parquet::Compression::type": + ParquetCompression_UNCOMPRESSED" parquet::Compression::UNCOMPRESSED" + ParquetCompression_SNAPPY" parquet::Compression::SNAPPY" + ParquetCompression_GZIP" parquet::Compression::GZIP" + ParquetCompression_LZO" parquet::Compression::LZO" + ParquetCompression_BROTLI" parquet::Compression::BROTLI" + ParquetCompression_LZ4" parquet::Compression::LZ4" + ParquetCompression_ZSTD" parquet::Compression::ZSTD" + + enum ParquetVersion" parquet::ParquetVersion::type": + ParquetVersion_V1" parquet::ParquetVersion::PARQUET_1_0" + ParquetVersion_V2_0" parquet::ParquetVersion::PARQUET_2_0" + ParquetVersion_V2_4" 
parquet::ParquetVersion::PARQUET_2_4" + ParquetVersion_V2_6" parquet::ParquetVersion::PARQUET_2_6" + + enum ParquetSortOrder" parquet::SortOrder::type": + ParquetSortOrder_SIGNED" parquet::SortOrder::SIGNED" + ParquetSortOrder_UNSIGNED" parquet::SortOrder::UNSIGNED" + ParquetSortOrder_UNKNOWN" parquet::SortOrder::UNKNOWN" + + cdef cppclass CParquetLogicalType" parquet::LogicalType": + c_string ToString() const + c_string ToJSON() const + ParquetLogicalTypeId type() const + + cdef cppclass CParquetDecimalType \ + " parquet::DecimalLogicalType"(CParquetLogicalType): + int32_t precision() const + int32_t scale() const + + cdef cppclass CParquetIntType \ + " parquet::IntLogicalType"(CParquetLogicalType): + int bit_width() const + c_bool is_signed() const + + cdef cppclass CParquetTimeType \ + " parquet::TimeLogicalType"(CParquetLogicalType): + c_bool is_adjusted_to_utc() const + ParquetTimeUnit time_unit() const + + cdef cppclass CParquetTimestampType \ + " parquet::TimestampLogicalType"(CParquetLogicalType): + c_bool is_adjusted_to_utc() const + ParquetTimeUnit time_unit() const + + cdef cppclass ColumnDescriptor" parquet::ColumnDescriptor": + c_bool Equals(const ColumnDescriptor& other) + + shared_ptr[ColumnPath] path() + int16_t max_definition_level() + int16_t max_repetition_level() + + ParquetType physical_type() + const shared_ptr[const CParquetLogicalType]& logical_type() + ParquetConvertedType converted_type() + const c_string& name() + int type_length() + int type_precision() + int type_scale() + + cdef cppclass SchemaDescriptor: + const ColumnDescriptor* Column(int i) + shared_ptr[Node] schema() + GroupNode* group() + c_bool Equals(const SchemaDescriptor& other) + c_string ToString() + int num_columns() + + cdef c_string FormatStatValue(ParquetType parquet_type, c_string val) + + enum ParquetCipher" parquet::ParquetCipher::type": + ParquetCipher_AES_GCM_V1" parquet::ParquetCipher::AES_GCM_V1" + ParquetCipher_AES_GCM_CTR_V1" parquet::ParquetCipher::AES_GCM_CTR_V1" + + struct AadMetadata: + c_string aad_prefix + c_string aad_file_unique + c_bool supply_aad_prefix + + struct EncryptionAlgorithm: + ParquetCipher algorithm + AadMetadata aad + +cdef extern from "parquet/api/reader.h" namespace "parquet" nogil: + cdef cppclass ColumnReader: + pass + + cdef cppclass BoolReader(ColumnReader): + pass + + cdef cppclass Int32Reader(ColumnReader): + pass + + cdef cppclass Int64Reader(ColumnReader): + pass + + cdef cppclass Int96Reader(ColumnReader): + pass + + cdef cppclass FloatReader(ColumnReader): + pass + + cdef cppclass DoubleReader(ColumnReader): + pass + + cdef cppclass ByteArrayReader(ColumnReader): + pass + + cdef cppclass RowGroupReader: + pass + + cdef cppclass CEncodedStatistics" parquet::EncodedStatistics": + const c_string& max() const + const c_string& min() const + int64_t null_count + int64_t distinct_count + bint has_min + bint has_max + bint has_null_count + bint has_distinct_count + + cdef cppclass ParquetByteArray" parquet::ByteArray": + uint32_t len + const uint8_t* ptr + + cdef cppclass ParquetFLBA" parquet::FLBA": + const uint8_t* ptr + + cdef cppclass CStatistics" parquet::Statistics": + int64_t null_count() const + int64_t distinct_count() const + int64_t num_values() const + bint HasMinMax() + bint HasNullCount() + bint HasDistinctCount() + c_bool Equals(const CStatistics&) const + void Reset() + c_string EncodeMin() + c_string EncodeMax() + CEncodedStatistics Encode() + void SetComparator() + ParquetType physical_type() const + const ColumnDescriptor* descr() const + + 
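+    # The typed statistics subclasses below each expose min()/max() in the
+    # column's physical C type (bool, int32/int64, float/double, byte array,
+    # fixed-length byte array).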
cdef cppclass CBoolStatistics" parquet::BoolStatistics"(CStatistics): + c_bool min() + c_bool max() + + cdef cppclass CInt32Statistics" parquet::Int32Statistics"(CStatistics): + int32_t min() + int32_t max() + + cdef cppclass CInt64Statistics" parquet::Int64Statistics"(CStatistics): + int64_t min() + int64_t max() + + cdef cppclass CFloatStatistics" parquet::FloatStatistics"(CStatistics): + float min() + float max() + + cdef cppclass CDoubleStatistics" parquet::DoubleStatistics"(CStatistics): + double min() + double max() + + cdef cppclass CByteArrayStatistics \ + " parquet::ByteArrayStatistics"(CStatistics): + ParquetByteArray min() + ParquetByteArray max() + + cdef cppclass CFLBAStatistics" parquet::FLBAStatistics"(CStatistics): + ParquetFLBA min() + ParquetFLBA max() + + cdef cppclass CColumnCryptoMetaData" parquet::ColumnCryptoMetaData": + shared_ptr[ColumnPath] path_in_schema() const + c_bool encrypted_with_footer_key() const + const c_string& key_metadata() const + + cdef cppclass ParquetIndexLocation" parquet::IndexLocation": + int64_t offset + int32_t length + + cdef cppclass CColumnChunkMetaData" parquet::ColumnChunkMetaData": + int64_t file_offset() const + const c_string& file_path() const + + c_bool is_metadata_set() const + ParquetType type() const + int64_t num_values() const + shared_ptr[ColumnPath] path_in_schema() const + bint is_stats_set() const + shared_ptr[CStatistics] statistics() const + ParquetCompression compression() const + const vector[ParquetEncoding]& encodings() const + c_bool Equals(const CColumnChunkMetaData&) const + + int64_t has_dictionary_page() const + int64_t dictionary_page_offset() const + int64_t data_page_offset() const + int64_t index_page_offset() const + int64_t total_compressed_size() const + int64_t total_uncompressed_size() const + unique_ptr[CColumnCryptoMetaData] crypto_metadata() const + optional[ParquetIndexLocation] GetColumnIndexLocation() const + optional[ParquetIndexLocation] GetOffsetIndexLocation() const + shared_ptr[const CKeyValueMetadata] key_value_metadata() const + + struct CSortingColumn" parquet::SortingColumn": + int column_idx + c_bool descending + c_bool nulls_first + + cdef cppclass CRowGroupMetaData" parquet::RowGroupMetaData": + c_bool Equals(const CRowGroupMetaData&) const + int num_columns() const + int64_t num_rows() const + int64_t total_byte_size() const + vector[CSortingColumn] sorting_columns() const + unique_ptr[CColumnChunkMetaData] ColumnChunk(int i) const + + cdef cppclass CFileMetaData" parquet::FileMetaData": + c_bool Equals(const CFileMetaData&) const + uint32_t size() + int num_columns() + int64_t num_rows() + int num_row_groups() + ParquetVersion version() + const c_string created_by() + int num_schema_elements() + + void set_file_path(const c_string& path) + void AppendRowGroups(const CFileMetaData& other) except + + + unique_ptr[CRowGroupMetaData] RowGroup(int i) + const SchemaDescriptor* schema() + shared_ptr[const CKeyValueMetadata] key_value_metadata() const + void WriteTo(COutputStream* dst) const + + inline c_bool is_encryption_algorithm_set() const + inline EncryptionAlgorithm encryption_algorithm() const + inline const c_string& footer_signing_key_metadata() const + + cdef shared_ptr[CFileMetaData] CFileMetaData_Make \ + " parquet::FileMetaData::Make"(const void* serialized_metadata, + uint32_t* metadata_len) + + cdef cppclass CReaderProperties" parquet::ReaderProperties": + c_bool is_buffered_stream_enabled() const + void enable_buffered_stream() + void disable_buffered_stream() + + void 
set_buffer_size(int64_t buf_size) + int64_t buffer_size() const + + void set_thrift_string_size_limit(int32_t size) + int32_t thrift_string_size_limit() const + + void set_thrift_container_size_limit(int32_t size) + int32_t thrift_container_size_limit() const + + void file_decryption_properties(shared_ptr[CFileDecryptionProperties] + decryption) + shared_ptr[CFileDecryptionProperties] file_decryption_properties() \ + const + + c_bool page_checksum_verification() const + void set_page_checksum_verification(c_bool check_crc) + + CReaderProperties default_reader_properties() + + cdef cppclass ArrowReaderProperties: + ArrowReaderProperties() + void set_read_dictionary(int column_index, c_bool read_dict) + c_bool read_dictionary() + void set_batch_size(int64_t batch_size) + int64_t batch_size() + void set_pre_buffer(c_bool pre_buffer) + c_bool pre_buffer() const + void set_cache_options(CCacheOptions options) + CCacheOptions cache_options() const + void set_coerce_int96_timestamp_unit(TimeUnit unit) + TimeUnit coerce_int96_timestamp_unit() const + + ArrowReaderProperties default_arrow_reader_properties() + + cdef cppclass ParquetFileReader: + shared_ptr[CFileMetaData] metadata() + + +cdef extern from "parquet/api/writer.h" namespace "parquet" nogil: + cdef cppclass WriterProperties: + cppclass Builder: + Builder* data_page_version(ParquetDataPageVersion version) + Builder* version(ParquetVersion version) + Builder* compression(ParquetCompression codec) + Builder* compression(const c_string& path, + ParquetCompression codec) + Builder* compression_level(int compression_level) + Builder* compression_level(const c_string& path, + int compression_level) + Builder* encryption( + shared_ptr[CFileEncryptionProperties] + file_encryption_properties) + Builder* disable_dictionary() + Builder* enable_dictionary() + Builder* enable_dictionary(const c_string& path) + Builder* set_sorting_columns(vector[CSortingColumn] sorting_columns) + Builder* disable_statistics() + Builder* enable_statistics() + Builder* enable_statistics(const c_string& path) + Builder* enable_store_decimal_as_integer() + Builder* disable_store_decimal_as_integer() + Builder* data_pagesize(int64_t size) + Builder* encoding(ParquetEncoding encoding) + Builder* encoding(const c_string& path, + ParquetEncoding encoding) + Builder* max_row_group_length(int64_t size) + Builder* write_batch_size(int64_t batch_size) + Builder* dictionary_pagesize_limit(int64_t dictionary_pagesize_limit) + Builder* enable_write_page_index() + Builder* disable_write_page_index() + Builder* enable_page_checksum() + Builder* disable_page_checksum() + shared_ptr[WriterProperties] build() + + cdef cppclass ArrowWriterProperties: + cppclass Builder: + Builder() + Builder* disable_deprecated_int96_timestamps() + Builder* enable_deprecated_int96_timestamps() + Builder* coerce_timestamps(TimeUnit unit) + Builder* allow_truncated_timestamps() + Builder* disallow_truncated_timestamps() + Builder* store_schema() + Builder* enable_compliant_nested_types() + Builder* disable_compliant_nested_types() + Builder* set_engine_version(ArrowWriterEngineVersion version) + shared_ptr[ArrowWriterProperties] build() + c_bool support_deprecated_int96_timestamps() + + +cdef extern from "parquet/arrow/reader.h" namespace "parquet::arrow" nogil: + cdef cppclass FileReader: + FileReader(CMemoryPool* pool, unique_ptr[ParquetFileReader] reader) + + CStatus GetSchema(shared_ptr[CSchema]* out) + + CStatus ReadColumn(int i, shared_ptr[CChunkedArray]* out) + CStatus ReadSchemaField(int i, 
shared_ptr[CChunkedArray]* out) + + int num_row_groups() + CStatus ReadRowGroup(int i, shared_ptr[CTable]* out) + CStatus ReadRowGroup(int i, const vector[int]& column_indices, + shared_ptr[CTable]* out) + + CStatus ReadRowGroups(const vector[int]& row_groups, + shared_ptr[CTable]* out) + CStatus ReadRowGroups(const vector[int]& row_groups, + const vector[int]& column_indices, + shared_ptr[CTable]* out) + + CStatus GetRecordBatchReader(const vector[int]& row_group_indices, + const vector[int]& column_indices, + unique_ptr[CRecordBatchReader]* out) + CStatus GetRecordBatchReader(const vector[int]& row_group_indices, + unique_ptr[CRecordBatchReader]* out) + + CStatus ReadTable(shared_ptr[CTable]* out) + CStatus ReadTable(const vector[int]& column_indices, + shared_ptr[CTable]* out) + + CStatus ScanContents(vector[int] columns, int32_t column_batch_size, + int64_t* num_rows) + + const ParquetFileReader* parquet_reader() + + void set_use_threads(c_bool use_threads) + + void set_batch_size(int64_t batch_size) + + cdef cppclass FileReaderBuilder: + FileReaderBuilder() + CStatus Open(const shared_ptr[CRandomAccessFile]& file, + const CReaderProperties& properties, + const shared_ptr[CFileMetaData]& metadata) + + ParquetFileReader* raw_reader() + FileReaderBuilder* memory_pool(CMemoryPool*) + FileReaderBuilder* properties(const ArrowReaderProperties&) + CStatus Build(unique_ptr[FileReader]* out) + + CStatus FromParquetSchema( + const SchemaDescriptor* parquet_schema, + const ArrowReaderProperties& properties, + const shared_ptr[const CKeyValueMetadata]& key_value_metadata, + shared_ptr[CSchema]* out) + + CStatus StatisticsAsScalars(const CStatistics& Statistics, + shared_ptr[CScalar]* min, + shared_ptr[CScalar]* max) + +cdef extern from "parquet/arrow/schema.h" namespace "parquet::arrow" nogil: + + CStatus ToParquetSchema( + const CSchema* arrow_schema, + const WriterProperties& properties, + const ArrowWriterProperties& arrow_properties, + shared_ptr[SchemaDescriptor]* out) + + +cdef extern from "parquet/properties.h" namespace "parquet" nogil: + cdef enum ArrowWriterEngineVersion: + V1 "parquet::ArrowWriterProperties::V1", + V2 "parquet::ArrowWriterProperties::V2" + + cdef cppclass ParquetDataPageVersion: + pass + + cdef ParquetDataPageVersion ParquetDataPageVersion_V1 \ + " parquet::ParquetDataPageVersion::V1" + cdef ParquetDataPageVersion ParquetDataPageVersion_V2 \ + " parquet::ParquetDataPageVersion::V2" + +cdef extern from "parquet/arrow/writer.h" namespace "parquet::arrow" nogil: + cdef cppclass FileWriter: + + @staticmethod + CResult[unique_ptr[FileWriter]] Open(const CSchema& schema, CMemoryPool* pool, + const shared_ptr[COutputStream]& sink, + const shared_ptr[WriterProperties]& properties, + const shared_ptr[ArrowWriterProperties]& arrow_properties) + + CStatus WriteTable(const CTable& table, int64_t chunk_size) + CStatus NewRowGroup(int64_t chunk_size) + CStatus Close() + CStatus AddKeyValueMetadata(const shared_ptr[const CKeyValueMetadata]& key_value_metadata) + + const shared_ptr[CFileMetaData] metadata() const + + CStatus WriteMetaDataFile( + const CFileMetaData& file_metadata, + const COutputStream* sink) + +cdef class FileEncryptionProperties: + """File-level encryption properties for the low-level API""" + cdef: + shared_ptr[CFileEncryptionProperties] properties + + @staticmethod + cdef inline FileEncryptionProperties wrap( + shared_ptr[CFileEncryptionProperties] properties): + + result = FileEncryptionProperties() + result.properties = properties + return result + + cdef inline 
shared_ptr[CFileEncryptionProperties] unwrap(self): + return self.properties + +cdef shared_ptr[WriterProperties] _create_writer_properties( + use_dictionary=*, + compression=*, + version=*, + write_statistics=*, + data_page_size=*, + compression_level=*, + use_byte_stream_split=*, + column_encoding=*, + data_page_version=*, + FileEncryptionProperties encryption_properties=*, + write_batch_size=*, + dictionary_pagesize_limit=*, + write_page_index=*, + write_page_checksum=*, + sorting_columns=*, + store_decimal_as_integer=*, +) except * + + +cdef shared_ptr[ArrowWriterProperties] _create_arrow_writer_properties( + use_deprecated_int96_timestamps=*, + coerce_timestamps=*, + allow_truncated_timestamps=*, + writer_engine_version=*, + use_compliant_nested_type=*, + store_schema=*, +) except * + +cdef class ParquetSchema(_Weakrefable): + cdef: + FileMetaData parent # the FileMetaData owning the SchemaDescriptor + const SchemaDescriptor* schema + +cdef class FileMetaData(_Weakrefable): + cdef: + shared_ptr[CFileMetaData] sp_metadata + CFileMetaData* _metadata + ParquetSchema _schema + + cdef inline init(self, const shared_ptr[CFileMetaData]& metadata): + self.sp_metadata = metadata + self._metadata = metadata.get() + +cdef class RowGroupMetaData(_Weakrefable): + cdef: + int index # for pickling support + unique_ptr[CRowGroupMetaData] up_metadata + CRowGroupMetaData* metadata + FileMetaData parent + +cdef class ColumnChunkMetaData(_Weakrefable): + cdef: + unique_ptr[CColumnChunkMetaData] up_metadata + CColumnChunkMetaData* metadata + RowGroupMetaData parent + + cdef inline init(self, RowGroupMetaData parent, int i): + self.up_metadata = parent.metadata.ColumnChunk(i) + self.metadata = self.up_metadata.get() + self.parent = parent + +cdef class Statistics(_Weakrefable): + cdef: + shared_ptr[CStatistics] statistics + ColumnChunkMetaData parent + + cdef inline init(self, const shared_ptr[CStatistics]& statistics, + ColumnChunkMetaData parent): + self.statistics = statistics + self.parent = parent + +cdef extern from "parquet/encryption/encryption.h" namespace "parquet" nogil: + cdef cppclass CFileDecryptionProperties\ + " parquet::FileDecryptionProperties": + pass + + cdef cppclass CFileEncryptionProperties\ + " parquet::FileEncryptionProperties": + pass + +cdef class FileDecryptionProperties: + """File-level decryption properties for the low-level API""" + cdef: + shared_ptr[CFileDecryptionProperties] properties + + @staticmethod + cdef inline FileDecryptionProperties wrap( + shared_ptr[CFileDecryptionProperties] properties): + + result = FileDecryptionProperties() + result.properties = properties + return result + + cdef inline shared_ptr[CFileDecryptionProperties] unwrap(self): + return self.properties diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_parquet_encryption.pxd b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_parquet_encryption.pxd new file mode 100644 index 0000000000000000000000000000000000000000..d52669501a4044838e576d3dac8f8a422874eaa6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_parquet_encryption.pxd @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ +# cython: language_level = 3 + +from pyarrow.includes.common cimport * +from pyarrow.includes.libparquet_encryption cimport * +from pyarrow._parquet cimport (ParquetCipher, + CFileEncryptionProperties, + CFileDecryptionProperties, + FileEncryptionProperties, + FileDecryptionProperties, + ParquetCipher_AES_GCM_V1, + ParquetCipher_AES_GCM_CTR_V1) +from pyarrow.lib cimport _Weakrefable + +cdef class CryptoFactory(_Weakrefable): + cdef shared_ptr[CPyCryptoFactory] factory + cdef init(self, callable_client_factory) + cdef inline shared_ptr[CPyCryptoFactory] unwrap(self) + +cdef class EncryptionConfiguration(_Weakrefable): + cdef shared_ptr[CEncryptionConfiguration] configuration + cdef inline shared_ptr[CEncryptionConfiguration] unwrap(self) nogil + +cdef class DecryptionConfiguration(_Weakrefable): + cdef shared_ptr[CDecryptionConfiguration] configuration + cdef inline shared_ptr[CDecryptionConfiguration] unwrap(self) nogil + +cdef class KmsConnectionConfig(_Weakrefable): + cdef shared_ptr[CKmsConnectionConfig] configuration + cdef inline shared_ptr[CKmsConnectionConfig] unwrap(self) nogil + + @staticmethod + cdef wrap(const CKmsConnectionConfig& config) + + +cdef shared_ptr[CCryptoFactory] pyarrow_unwrap_cryptofactory(object crypto_factory) except * +cdef shared_ptr[CKmsConnectionConfig] pyarrow_unwrap_kmsconnectionconfig(object kmsconnectionconfig) except * +cdef shared_ptr[CEncryptionConfiguration] pyarrow_unwrap_encryptionconfig(object encryptionconfig) except * +cdef shared_ptr[CDecryptionConfiguration] pyarrow_unwrap_decryptionconfig(object decryptionconfig) except * diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_parquet_encryption.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_parquet_encryption.pyx new file mode 100644 index 0000000000000000000000000000000000000000..d0a9a6612328c547bc724d6fcf2d37ae5e7badd3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_parquet_encryption.pyx @@ -0,0 +1,484 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# cython: profile=False +# distutils: language = c++ + +from datetime import timedelta + +from cython.operator cimport dereference as deref +from libcpp.memory cimport shared_ptr +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.lib cimport _Weakrefable +from pyarrow.lib import tobytes, frombytes + + +cdef ParquetCipher cipher_from_name(name): + name = name.upper() + if name == 'AES_GCM_V1': + return ParquetCipher_AES_GCM_V1 + elif name == 'AES_GCM_CTR_V1': + return ParquetCipher_AES_GCM_CTR_V1 + else: + raise ValueError(f'Invalid cipher name: {name!r}') + + +cdef cipher_to_name(ParquetCipher cipher): + if ParquetCipher_AES_GCM_V1 == cipher: + return 'AES_GCM_V1' + elif ParquetCipher_AES_GCM_CTR_V1 == cipher: + return 'AES_GCM_CTR_V1' + else: + raise ValueError('Invalid cipher value: {0}'.format(cipher)) + +cdef class EncryptionConfiguration(_Weakrefable): + """Configuration of the encryption, such as which columns to encrypt""" + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, footer_key, *, column_keys=None, + encryption_algorithm=None, + plaintext_footer=None, double_wrapping=None, + cache_lifetime=None, internal_key_material=None, + data_key_length_bits=None): + self.configuration.reset( + new CEncryptionConfiguration(tobytes(footer_key))) + if column_keys is not None: + self.column_keys = column_keys + if encryption_algorithm is not None: + self.encryption_algorithm = encryption_algorithm + if plaintext_footer is not None: + self.plaintext_footer = plaintext_footer + if double_wrapping is not None: + self.double_wrapping = double_wrapping + if cache_lifetime is not None: + self.cache_lifetime = cache_lifetime + if internal_key_material is not None: + self.internal_key_material = internal_key_material + if data_key_length_bits is not None: + self.data_key_length_bits = data_key_length_bits + + @property + def footer_key(self): + """ID of the master key for footer encryption/signing""" + return frombytes(self.configuration.get().footer_key) + + @property + def column_keys(self): + """ + List of columns to encrypt, with master key IDs. + """ + column_keys_str = frombytes(self.configuration.get().column_keys) + # Convert from "masterKeyID:colName,colName;masterKeyID:colName..." + # (see HIVE-21848) to dictionary of master key ID to column name lists + column_keys_to_key_list_str = dict(subString.replace(" ", "").split( + ":") for subString in column_keys_str.split(";")) + column_keys_dict = {k: v.split( + ",") for k, v in column_keys_to_key_list_str.items()} + return column_keys_dict + + @column_keys.setter + def column_keys(self, dict value): + if value is not None: + # convert a dictionary such as + # '{"key1": ["col1 ", "col2"], "key2": ["col3 ", "col4"]}'' + # to the string defined by the spec + # 'key1: col1 , col2; key2: col3 , col4' + column_keys = "; ".join( + ["{}: {}".format(k, ", ".join(v)) for k, v in value.items()]) + self.configuration.get().column_keys = tobytes(column_keys) + + @property + def encryption_algorithm(self): + """Parquet encryption algorithm. 
+ Can be "AES_GCM_V1" (default), or "AES_GCM_CTR_V1".""" + return cipher_to_name(self.configuration.get().encryption_algorithm) + + @encryption_algorithm.setter + def encryption_algorithm(self, value): + cipher = cipher_from_name(value) + self.configuration.get().encryption_algorithm = cipher + + @property + def plaintext_footer(self): + """Write files with plaintext footer.""" + return self.configuration.get().plaintext_footer + + @plaintext_footer.setter + def plaintext_footer(self, value): + self.configuration.get().plaintext_footer = value + + @property + def double_wrapping(self): + """Use double wrapping - where data encryption keys (DEKs) are + encrypted with key encryption keys (KEKs), which in turn are + encrypted with master keys. + If set to false, use single wrapping - where DEKs are + encrypted directly with master keys.""" + return self.configuration.get().double_wrapping + + @double_wrapping.setter + def double_wrapping(self, value): + self.configuration.get().double_wrapping = value + + @property + def cache_lifetime(self): + """Lifetime of cached entities (key encryption keys, + local wrapping keys, KMS client objects).""" + return timedelta( + seconds=self.configuration.get().cache_lifetime_seconds) + + @cache_lifetime.setter + def cache_lifetime(self, value): + if not isinstance(value, timedelta): + raise TypeError("cache_lifetime should be a timedelta") + self.configuration.get().cache_lifetime_seconds = value.total_seconds() + + @property + def internal_key_material(self): + """Store key material inside Parquet file footers; this mode doesn’t + produce additional files. If set to false, key material is stored in + separate files in the same folder, which enables key rotation for + immutable Parquet files.""" + return self.configuration.get().internal_key_material + + @internal_key_material.setter + def internal_key_material(self, value): + self.configuration.get().internal_key_material = value + + @property + def data_key_length_bits(self): + """Length of data encryption keys (DEKs), randomly generated by parquet key + management tools. 
Can be 128, 192 or 256 bits.""" + return self.configuration.get().data_key_length_bits + + @data_key_length_bits.setter + def data_key_length_bits(self, value): + self.configuration.get().data_key_length_bits = value + + cdef inline shared_ptr[CEncryptionConfiguration] unwrap(self) nogil: + return self.configuration + + +cdef class DecryptionConfiguration(_Weakrefable): + """Configuration of the decryption, such as cache timeout.""" + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, *, cache_lifetime=None): + self.configuration.reset(new CDecryptionConfiguration()) + + @property + def cache_lifetime(self): + """Lifetime of cached entities (key encryption keys, + local wrapping keys, KMS client objects).""" + return timedelta( + seconds=self.configuration.get().cache_lifetime_seconds) + + @cache_lifetime.setter + def cache_lifetime(self, value): + self.configuration.get().cache_lifetime_seconds = value.total_seconds() + + cdef inline shared_ptr[CDecryptionConfiguration] unwrap(self) nogil: + return self.configuration + + +cdef class KmsConnectionConfig(_Weakrefable): + """Configuration of the connection to the Key Management Service (KMS)""" + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, *, kms_instance_id=None, kms_instance_url=None, + key_access_token=None, custom_kms_conf=None): + self.configuration.reset(new CKmsConnectionConfig()) + if kms_instance_id is not None: + self.kms_instance_id = kms_instance_id + if kms_instance_url is not None: + self.kms_instance_url = kms_instance_url + if key_access_token is None: + self.key_access_token = b'DEFAULT' + else: + self.key_access_token = key_access_token + if custom_kms_conf is not None: + self.custom_kms_conf = custom_kms_conf + + @property + def kms_instance_id(self): + """ID of the KMS instance that will be used for encryption + (if multiple KMS instances are available).""" + return frombytes(self.configuration.get().kms_instance_id) + + @kms_instance_id.setter + def kms_instance_id(self, value): + self.configuration.get().kms_instance_id = tobytes(value) + + @property + def kms_instance_url(self): + """URL of the KMS instance.""" + return frombytes(self.configuration.get().kms_instance_url) + + @kms_instance_url.setter + def kms_instance_url(self, value): + self.configuration.get().kms_instance_url = tobytes(value) + + @property + def key_access_token(self): + """Authorization token that will be passed to KMS.""" + return frombytes(self.configuration.get() + .refreshable_key_access_token.get().value()) + + @key_access_token.setter + def key_access_token(self, value): + self.refresh_key_access_token(value) + + @property + def custom_kms_conf(self): + """A dictionary with KMS-type-specific configuration""" + custom_kms_conf = { + frombytes(k): frombytes(v) + for k, v in self.configuration.get().custom_kms_conf + } + return custom_kms_conf + + @custom_kms_conf.setter + def custom_kms_conf(self, dict value): + if value is not None: + for k, v in value.items(): + if isinstance(k, str) and isinstance(v, str): + self.configuration.get().custom_kms_conf[tobytes(k)] = \ + tobytes(v) + else: + raise TypeError("Expected custom_kms_conf to be " + + "a dictionary of strings") + + def refresh_key_access_token(self, value): + cdef: + shared_ptr[CKeyAccessToken] c_key_access_token = \ + self.configuration.get().refreshable_key_access_token + + c_key_access_token.get().Refresh(tobytes(value)) + + cdef inline shared_ptr[CKmsConnectionConfig] unwrap(self) nogil: + return self.configuration + 
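+    # wrap() builds a Python KmsConnectionConfig around an existing C++
+    # CKmsConnectionConfig; it is used by the KMS client factory callback
+    # below to hand the C++ configuration to user code.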
+ @staticmethod + cdef wrap(const CKmsConnectionConfig& config): + result = KmsConnectionConfig() + result.configuration = make_shared[CKmsConnectionConfig](move(config)) + return result + + +# Callback definitions for CPyKmsClientVtable +cdef void _cb_wrap_key( + handler, const c_string& key_bytes, + const c_string& master_key_identifier, c_string* out) except *: + mkid_str = frombytes(master_key_identifier) + wrapped_key = handler.wrap_key(key_bytes, mkid_str) + out[0] = tobytes(wrapped_key) + + +cdef void _cb_unwrap_key( + handler, const c_string& wrapped_key, + const c_string& master_key_identifier, c_string* out) except *: + mkid_str = frombytes(master_key_identifier) + wk_str = frombytes(wrapped_key) + key = handler.unwrap_key(wk_str, mkid_str) + out[0] = tobytes(key) + + +cdef class KmsClient(_Weakrefable): + """The abstract base class for KmsClient implementations.""" + cdef: + shared_ptr[CKmsClient] client + + def __init__(self): + self.init() + + cdef init(self): + cdef: + CPyKmsClientVtable vtable = CPyKmsClientVtable() + + vtable.wrap_key = _cb_wrap_key + vtable.unwrap_key = _cb_unwrap_key + + self.client.reset(new CPyKmsClient(self, vtable)) + + def wrap_key(self, key_bytes, master_key_identifier): + """Wrap a key - encrypt it with the master key.""" + raise NotImplementedError() + + def unwrap_key(self, wrapped_key, master_key_identifier): + """Unwrap a key - decrypt it with the master key.""" + raise NotImplementedError() + + cdef inline shared_ptr[CKmsClient] unwrap(self) nogil: + return self.client + + +# Callback definition for CPyKmsClientFactoryVtable +cdef void _cb_create_kms_client( + handler, + const CKmsConnectionConfig& kms_connection_config, + shared_ptr[CKmsClient]* out) except *: + connection_config = KmsConnectionConfig.wrap(kms_connection_config) + + result = handler(connection_config) + if not isinstance(result, KmsClient): + raise TypeError( + "callable must return KmsClient instances, but got {}".format( + type(result))) + + out[0] = ( result).unwrap() + + +cdef class CryptoFactory(_Weakrefable): + """ A factory that produces the low-level FileEncryptionProperties and + FileDecryptionProperties objects, from the high-level parameters.""" + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, kms_client_factory): + """Create CryptoFactory. + + Parameters + ---------- + kms_client_factory : a callable that accepts KmsConnectionConfig + and returns a KmsClient + """ + self.factory.reset(new CPyCryptoFactory()) + + if callable(kms_client_factory): + self.init(kms_client_factory) + else: + raise TypeError("Parameter kms_client_factory must be a callable") + + cdef init(self, callable_client_factory): + cdef: + CPyKmsClientFactoryVtable vtable + shared_ptr[CPyKmsClientFactory] kms_client_factory + + vtable.create_kms_client = _cb_create_kms_client + kms_client_factory.reset( + new CPyKmsClientFactory(callable_client_factory, vtable)) + # A KmsClientFactory object must be registered + # via this method before calling any of + # file_encryption_properties()/file_decryption_properties() methods. + self.factory.get().RegisterKmsClientFactory( + static_pointer_cast[CKmsClientFactory, CPyKmsClientFactory]( + kms_client_factory)) + + def file_encryption_properties(self, + KmsConnectionConfig kms_connection_config, + EncryptionConfiguration encryption_config): + """Create file encryption properties. 
+ + Parameters + ---------- + kms_connection_config : KmsConnectionConfig + Configuration of connection to KMS + + encryption_config : EncryptionConfiguration + Configuration of the encryption, such as which columns to encrypt + + Returns + ------- + file_encryption_properties : FileEncryptionProperties + File encryption properties. + """ + cdef: + CResult[shared_ptr[CFileEncryptionProperties]] \ + file_encryption_properties_result + with nogil: + file_encryption_properties_result = \ + self.factory.get().SafeGetFileEncryptionProperties( + deref(kms_connection_config.unwrap().get()), + deref(encryption_config.unwrap().get())) + file_encryption_properties = GetResultValue( + file_encryption_properties_result) + return FileEncryptionProperties.wrap(file_encryption_properties) + + def file_decryption_properties( + self, + KmsConnectionConfig kms_connection_config, + DecryptionConfiguration decryption_config=None): + """Create file decryption properties. + + Parameters + ---------- + kms_connection_config : KmsConnectionConfig + Configuration of connection to KMS + + decryption_config : DecryptionConfiguration, default None + Configuration of the decryption, such as cache timeout. + Can be None. + + Returns + ------- + file_decryption_properties : FileDecryptionProperties + File decryption properties. + """ + cdef: + CDecryptionConfiguration c_decryption_config + CResult[shared_ptr[CFileDecryptionProperties]] \ + c_file_decryption_properties + if decryption_config is None: + c_decryption_config = CDecryptionConfiguration() + else: + c_decryption_config = deref(decryption_config.unwrap().get()) + with nogil: + c_file_decryption_properties = \ + self.factory.get().SafeGetFileDecryptionProperties( + deref(kms_connection_config.unwrap().get()), + c_decryption_config) + file_decryption_properties = GetResultValue( + c_file_decryption_properties) + return FileDecryptionProperties.wrap(file_decryption_properties) + + def remove_cache_entries_for_token(self, access_token): + self.factory.get().RemoveCacheEntriesForToken(tobytes(access_token)) + + def remove_cache_entries_for_all_tokens(self): + self.factory.get().RemoveCacheEntriesForAllTokens() + + cdef inline shared_ptr[CPyCryptoFactory] unwrap(self): + return self.factory + + +cdef shared_ptr[CCryptoFactory] pyarrow_unwrap_cryptofactory(object crypto_factory) except *: + if isinstance(crypto_factory, CryptoFactory): + pycf = ( crypto_factory).unwrap() + return static_pointer_cast[CCryptoFactory, CPyCryptoFactory](pycf) + raise TypeError("Expected CryptoFactory, got %s" % type(crypto_factory)) + + +cdef shared_ptr[CKmsConnectionConfig] pyarrow_unwrap_kmsconnectionconfig(object kmsconnectionconfig) except *: + if isinstance(kmsconnectionconfig, KmsConnectionConfig): + return ( kmsconnectionconfig).unwrap() + raise TypeError("Expected KmsConnectionConfig, got %s" % type(kmsconnectionconfig)) + + +cdef shared_ptr[CEncryptionConfiguration] pyarrow_unwrap_encryptionconfig(object encryptionconfig) except *: + if isinstance(encryptionconfig, EncryptionConfiguration): + return ( encryptionconfig).unwrap() + raise TypeError("Expected EncryptionConfiguration, got %s" % type(encryptionconfig)) + + +cdef shared_ptr[CDecryptionConfiguration] pyarrow_unwrap_decryptionconfig(object decryptionconfig) except *: + if isinstance(decryptionconfig, DecryptionConfiguration): + return ( decryptionconfig).unwrap() + raise TypeError("Expected DecryptionConfiguration, got %s" % type(decryptionconfig)) diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_pyarrow_cpp_tests.cpython-312-x86_64-linux-gnu.so b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_pyarrow_cpp_tests.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..5811bc35241bbb8023ec86d99106228cc7e5d5bc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_pyarrow_cpp_tests.cpython-312-x86_64-linux-gnu.so differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_pyarrow_cpp_tests.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_pyarrow_cpp_tests.pyx new file mode 100644 index 0000000000000000000000000000000000000000..adb148351306c02667346b3750c08f2efd8a6625 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/_pyarrow_cpp_tests.pyx @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: profile=False, binding=True +# distutils: language = c++ + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.lib cimport check_status + +from pyarrow.lib import frombytes + + +cdef class CppTestCase: + """ + A simple wrapper for a C++ test case. + """ + cdef: + CTestCase c_case + + @staticmethod + cdef wrap(CTestCase c_case): + cdef: + CppTestCase obj + obj = CppTestCase.__new__(CppTestCase) + obj.c_case = c_case + return obj + + @property + def name(self): + return frombytes(self.c_case.name) + + def __repr__(self): + return f"<{self.__class__.__name__} {self.name!r}>" + + def __call__(self): + check_status(self.c_case.func()) + + +def get_cpp_tests(): + """ + Get a list of C++ test cases. + """ + cases = [] + c_cases = GetCppTestCases() + for c_case in c_cases: + cases.append(CppTestCase.wrap(c_case)) + return cases diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/acero.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/acero.py new file mode 100644 index 0000000000000000000000000000000000000000..706338bd8cdb88e1fa9a1d5f301701d7116bc0f3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/acero.py @@ -0,0 +1,410 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# --------------------------------------------------------------------- +# Implement Internal ExecPlan bindings + +# cython: profile=False +# distutils: language = c++ +# cython: language_level = 3 + +from pyarrow.lib import Table, RecordBatch +from pyarrow.compute import Expression, field + +try: + from pyarrow._acero import ( # noqa + Declaration, + ExecNodeOptions, + TableSourceNodeOptions, + FilterNodeOptions, + ProjectNodeOptions, + AggregateNodeOptions, + OrderByNodeOptions, + HashJoinNodeOptions, + AsofJoinNodeOptions, + ) +except ImportError as exc: + raise ImportError( + f"The pyarrow installation is not built with support for 'acero' ({str(exc)})" + ) from None + + +try: + import pyarrow.dataset as ds + from pyarrow._dataset import ScanNodeOptions +except ImportError: + class DatasetModuleStub: + class Dataset: + pass + + class InMemoryDataset: + pass + ds = DatasetModuleStub + + +def _dataset_to_decl(dataset, use_threads=True, require_sequenced_output=False): + decl = Declaration("scan", ScanNodeOptions( + dataset, use_threads=use_threads, + require_sequenced_output=require_sequenced_output)) + + # Get rid of special dataset columns + # "__fragment_index", "__batch_index", "__last_in_fragment", "__filename" + projections = [field(f) for f in dataset.schema.names] + decl = Declaration.from_sequence( + [decl, Declaration("project", ProjectNodeOptions(projections))] + ) + + filter_expr = dataset._scan_options.get("filter") + if filter_expr is not None: + # Filters applied in CScanNodeOptions are "best effort" for the scan node itself + # so we always need to inject an additional Filter node to apply them for real. + decl = Declaration.from_sequence( + [decl, Declaration("filter", FilterNodeOptions(filter_expr))] + ) + + return decl + + +def _perform_join(join_type, left_operand, left_keys, + right_operand, right_keys, + left_suffix=None, right_suffix=None, + use_threads=True, coalesce_keys=False, + output_type=Table): + """ + Perform join of two tables or datasets. + + The result will be an output table with the result of the join operation + + Parameters + ---------- + join_type : str + One of supported join types. + left_operand : Table or Dataset + The left operand for the join operation. + left_keys : str or list[str] + The left key (or keys) on which the join operation should be performed. + right_operand : Table or Dataset + The right operand for the join operation. + right_keys : str or list[str] + The right key (or keys) on which the join operation should be performed. + left_suffix : str, default None + Which suffix to add to left column names. This prevents confusion + when the columns in left and right operands have colliding names. + right_suffix : str, default None + Which suffix to add to the right column names. This prevents confusion + when the columns in left and right operands have colliding names. + use_threads : bool, default True + Whether to use multithreading or not. + coalesce_keys : bool, default False + If the duplicated keys should be omitted from one of the sides + in the join result. 
+ output_type: Table or InMemoryDataset + The output type for the exec plan result. + + Returns + ------- + result_table : Table or InMemoryDataset + """ + if not isinstance(left_operand, (Table, ds.Dataset)): + raise TypeError(f"Expected Table or Dataset, got {type(left_operand)}") + if not isinstance(right_operand, (Table, ds.Dataset)): + raise TypeError(f"Expected Table or Dataset, got {type(right_operand)}") + + # Prepare left and right tables Keys to send them to the C++ function + left_keys_order = {} + if not isinstance(left_keys, (tuple, list)): + left_keys = [left_keys] + for idx, key in enumerate(left_keys): + left_keys_order[key] = idx + + right_keys_order = {} + if not isinstance(right_keys, (list, tuple)): + right_keys = [right_keys] + for idx, key in enumerate(right_keys): + right_keys_order[key] = idx + + # By default expose all columns on both left and right table + left_columns = left_operand.schema.names + right_columns = right_operand.schema.names + + # Pick the join type + if join_type == "left semi" or join_type == "left anti": + right_columns = [] + elif join_type == "right semi" or join_type == "right anti": + left_columns = [] + elif join_type == "inner" or join_type == "left outer": + right_columns = [ + col for col in right_columns if col not in right_keys_order + ] + elif join_type == "right outer": + left_columns = [ + col for col in left_columns if col not in left_keys_order + ] + + # Turn the columns to vectors of FieldRefs + # and set aside indices of keys. + left_column_keys_indices = {} + for idx, colname in enumerate(left_columns): + if colname in left_keys: + left_column_keys_indices[colname] = idx + right_column_keys_indices = {} + for idx, colname in enumerate(right_columns): + if colname in right_keys: + right_column_keys_indices[colname] = idx + + # Add the join node to the execplan + if isinstance(left_operand, ds.Dataset): + left_source = _dataset_to_decl(left_operand, use_threads=use_threads) + else: + left_source = Declaration("table_source", TableSourceNodeOptions(left_operand)) + if isinstance(right_operand, ds.Dataset): + right_source = _dataset_to_decl(right_operand, use_threads=use_threads) + else: + right_source = Declaration( + "table_source", TableSourceNodeOptions(right_operand) + ) + + if coalesce_keys: + join_opts = HashJoinNodeOptions( + join_type, left_keys, right_keys, left_columns, right_columns, + output_suffix_for_left=left_suffix or "", + output_suffix_for_right=right_suffix or "", + ) + else: + join_opts = HashJoinNodeOptions( + join_type, left_keys, right_keys, + output_suffix_for_left=left_suffix or "", + output_suffix_for_right=right_suffix or "", + ) + decl = Declaration( + "hashjoin", options=join_opts, inputs=[left_source, right_source] + ) + + if coalesce_keys and join_type == "full outer": + # In case of full outer joins, the join operation will output all columns + # so that we can coalesce the keys and exclude duplicates in a subsequent + # projection. + left_columns_set = set(left_columns) + right_columns_set = set(right_columns) + # Where the right table columns start. + right_operand_index = len(left_columns) + projected_col_names = [] + projections = [] + for idx, col in enumerate(left_columns + right_columns): + if idx < len(left_columns) and col in left_column_keys_indices: + # Include keys only once and coalesce left+right table keys. + projected_col_names.append(col) + # Get the index of the right key that is being paired + # with this left key. 
We do so by retrieving the name + # of the right key that is in the same position in the provided keys + # and then looking up the index for that name in the right table. + right_key_index = right_column_keys_indices[ + right_keys[left_keys_order[col]]] + projections.append( + Expression._call("coalesce", [ + Expression._field(idx), Expression._field( + right_operand_index+right_key_index) + ]) + ) + elif idx >= right_operand_index and col in right_column_keys_indices: + # Do not include right table keys. As they would lead to duplicated keys + continue + else: + # For all the other columns include them as they are. + # Just recompute the suffixes that the join produced as the projection + # would lose them otherwise. + if ( + left_suffix and idx < right_operand_index + and col in right_columns_set + ): + col += left_suffix + if ( + right_suffix and idx >= right_operand_index + and col in left_columns_set + ): + col += right_suffix + projected_col_names.append(col) + projections.append( + Expression._field(idx) + ) + projection = Declaration( + "project", ProjectNodeOptions(projections, projected_col_names) + ) + decl = Declaration.from_sequence([decl, projection]) + + result_table = decl.to_table(use_threads=use_threads) + + if output_type == Table: + return result_table + elif output_type == ds.InMemoryDataset: + return ds.InMemoryDataset(result_table) + else: + raise TypeError("Unsupported output type") + + +def _perform_join_asof(left_operand, left_on, left_by, + right_operand, right_on, right_by, + tolerance, use_threads=True, + output_type=Table): + """ + Perform asof join of two tables or datasets. + + The result will be an output table with the result of the join operation + + Parameters + ---------- + left_operand : Table or Dataset + The left operand for the join operation. + left_on : str + The left key (or keys) on which the join operation should be performed. + left_by: str or list[str] + The left key (or keys) on which the join operation should be performed. + right_operand : Table or Dataset + The right operand for the join operation. + right_on : str or list[str] + The right key (or keys) on which the join operation should be performed. + right_by: str or list[str] + The right key (or keys) on which the join operation should be performed. + tolerance : int + The tolerance to use for the asof join. The tolerance is interpreted in + the same units as the "on" key. + output_type: Table or InMemoryDataset + The output type for the exec plan result. + + Returns + ------- + result_table : Table or InMemoryDataset + """ + if not isinstance(left_operand, (Table, ds.Dataset)): + raise TypeError(f"Expected Table or Dataset, got {type(left_operand)}") + if not isinstance(right_operand, (Table, ds.Dataset)): + raise TypeError(f"Expected Table or Dataset, got {type(right_operand)}") + + if not isinstance(left_by, (tuple, list)): + left_by = [left_by] + if not isinstance(right_by, (tuple, list)): + right_by = [right_by] + + # AsofJoin does not return on or by columns for right_operand. + right_columns = [ + col for col in right_operand.schema.names + if col not in [right_on] + right_by + ] + columns_collisions = set(left_operand.schema.names) & set(right_columns) + if columns_collisions: + raise ValueError( + "Columns {} present in both tables. 
AsofJoin does not support " + "column collisions.".format(columns_collisions), + ) + + # Add the join node to the execplan + if isinstance(left_operand, ds.Dataset): + left_source = _dataset_to_decl( + left_operand, + use_threads=use_threads, + require_sequenced_output=True) + else: + left_source = Declaration( + "table_source", TableSourceNodeOptions(left_operand), + ) + if isinstance(right_operand, ds.Dataset): + right_source = _dataset_to_decl( + right_operand, use_threads=use_threads, + require_sequenced_output=True) + else: + right_source = Declaration( + "table_source", TableSourceNodeOptions(right_operand) + ) + + join_opts = AsofJoinNodeOptions( + left_on, left_by, right_on, right_by, tolerance + ) + decl = Declaration( + "asofjoin", options=join_opts, inputs=[left_source, right_source] + ) + + result_table = decl.to_table(use_threads=use_threads) + + if output_type == Table: + return result_table + elif output_type == ds.InMemoryDataset: + return ds.InMemoryDataset(result_table) + else: + raise TypeError("Unsupported output type") + + +def _filter_table(table, expression): + """Filter rows of a table based on the provided expression. + + The result will be an output table with only the rows matching + the provided expression. + + Parameters + ---------- + table : Table or RecordBatch + Table that should be filtered. + expression : Expression + The expression on which rows should be filtered. + + Returns + ------- + Table + """ + is_batch = False + if isinstance(table, RecordBatch): + table = Table.from_batches([table]) + is_batch = True + + decl = Declaration.from_sequence([ + Declaration("table_source", options=TableSourceNodeOptions(table)), + Declaration("filter", options=FilterNodeOptions(expression)) + ]) + result = decl.to_table(use_threads=True) + if is_batch: + result = result.combine_chunks().to_batches()[0] + return result + + +def _sort_source(table_or_dataset, sort_keys, output_type=Table, **kwargs): + + if isinstance(table_or_dataset, ds.Dataset): + data_source = _dataset_to_decl(table_or_dataset, use_threads=True) + else: + data_source = Declaration( + "table_source", TableSourceNodeOptions(table_or_dataset) + ) + + order_by = Declaration("order_by", OrderByNodeOptions(sort_keys, **kwargs)) + + decl = Declaration.from_sequence([data_source, order_by]) + result_table = decl.to_table(use_threads=True) + + if output_type == Table: + return result_table + elif output_type == ds.InMemoryDataset: + return ds.InMemoryDataset(result_table) + else: + raise TypeError("Unsupported output type") + + +def _group_by(table, aggregates, keys, use_threads=True): + + decl = Declaration.from_sequence([ + Declaration("table_source", TableSourceNodeOptions(table)), + Declaration("aggregate", AggregateNodeOptions(aggregates, keys=keys)) + ]) + return decl.to_table(use_threads=use_threads) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/array.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/array.pxi new file mode 100644 index 0000000000000000000000000000000000000000..2ef42051d9ad250bb07a1ec1584ac67211a066bc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/array.pxi @@ -0,0 +1,4837 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
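# --- Hedged sketch of the pattern shared by the acero helpers above
# (_perform_join, _perform_join_asof, _filter_table, _sort_source, _group_by):
# build Declaration nodes and execute them with
# Declaration.from_sequence(...).to_table(). This uses the public
# pyarrow.acero API; the table and column names are illustrative placeholders.
import pyarrow as pa
import pyarrow.compute as pc
from pyarrow.acero import (
    Declaration, TableSourceNodeOptions, FilterNodeOptions, ProjectNodeOptions,
)

table = pa.table({
    "n_legs": [2, 4, 5, 100],
    "animal": ["Flamingo", "Horse", "Brittle stars", "Centipede"],
})

plan = Declaration.from_sequence([
    # feed the in-memory table into the plan
    Declaration("table_source", TableSourceNodeOptions(table)),
    # keep rows with fewer than 10 legs
    Declaration("filter", FilterNodeOptions(pc.field("n_legs") < 10)),
    # re-project the surviving columns (mirrors how _dataset_to_decl
    # projects the schema names to drop special dataset columns)
    Declaration("project", ProjectNodeOptions(
        [pc.field("animal"), pc.field("n_legs")],
        ["animal", "n_legs"])),
])
result = plan.to_table(use_threads=True)  # 3-row pyarrow.Table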
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from cpython.pycapsule cimport PyCapsule_CheckExact, PyCapsule_GetPointer, PyCapsule_New + +import os +import warnings +from cython import sizeof + + +cdef _sequence_to_array(object sequence, object mask, object size, + DataType type, CMemoryPool* pool, c_bool from_pandas): + cdef: + int64_t c_size + PyConversionOptions options + shared_ptr[CChunkedArray] chunked + + if type is not None: + options.type = type.sp_type + + if size is not None: + options.size = size + + options.from_pandas = from_pandas + options.ignore_timezone = os.environ.get('PYARROW_IGNORE_TIMEZONE', False) + + with nogil: + chunked = GetResultValue( + ConvertPySequence(sequence, mask, options, pool) + ) + + if chunked.get().num_chunks() == 1: + return pyarrow_wrap_array(chunked.get().chunk(0)) + else: + return pyarrow_wrap_chunked_array(chunked) + + +cdef inline _is_array_like(obj): + if np is None: + return False + if isinstance(obj, np.ndarray): + return True + return pandas_api._have_pandas_internal() and pandas_api.is_array_like(obj) + + +def _ndarray_to_arrow_type(object values, DataType type): + return pyarrow_wrap_data_type(_ndarray_to_type(values, type)) + + +cdef shared_ptr[CDataType] _ndarray_to_type(object values, + DataType type) except *: + cdef shared_ptr[CDataType] c_type + + dtype = values.dtype + + if type is None and dtype != object: + c_type = GetResultValue(NumPyDtypeToArrow(dtype)) + + if type is not None: + c_type = type.sp_type + + return c_type + + +cdef _ndarray_to_array(object values, object mask, DataType type, + c_bool from_pandas, c_bool safe, CMemoryPool* pool): + cdef: + shared_ptr[CChunkedArray] chunked_out + shared_ptr[CDataType] c_type = _ndarray_to_type(values, type) + CCastOptions cast_options = CCastOptions(safe) + + with nogil: + check_status(NdarrayToArrow(pool, values, mask, from_pandas, + c_type, cast_options, &chunked_out)) + + if chunked_out.get().num_chunks() > 1: + return pyarrow_wrap_chunked_array(chunked_out) + else: + return pyarrow_wrap_array(chunked_out.get().chunk(0)) + + +cdef _codes_to_indices(object codes, object mask, DataType type, + MemoryPool memory_pool): + """ + Convert the codes of a pandas Categorical to indices for a pyarrow + DictionaryArray, taking into account missing values + mask + """ + if mask is None: + mask = codes == -1 + else: + mask = mask | (codes == -1) + return array(codes, mask=mask, type=type, memory_pool=memory_pool) + + +def _handle_arrow_array_protocol(obj, type, mask, size): + if mask is not None or size is not None: + raise ValueError( + "Cannot specify a mask or a size when passing an object that is " + "converted with the __arrow_array__ protocol.") + res = obj.__arrow_array__(type=type) + if not isinstance(res, (Array, ChunkedArray)): + raise TypeError("The object's __arrow_array__ method does not " + "return a pyarrow Array or ChunkedArray.") + if isinstance(res, ChunkedArray) and res.num_chunks==1: + res = res.chunk(0) + if type is not None and res.type != 
type: + res = res.cast(type) + return res + + +def array(object obj, type=None, mask=None, size=None, from_pandas=None, + bint safe=True, MemoryPool memory_pool=None): + """ + Create pyarrow.Array instance from a Python object. + + Parameters + ---------- + obj : sequence, iterable, ndarray, pandas.Series, Arrow-compatible array + If both type and size are specified may be a single use iterable. If + not strongly-typed, Arrow type will be inferred for resulting array. + Any Arrow-compatible array that implements the Arrow PyCapsule Protocol + (has an ``__arrow_c_array__`` or ``__arrow_c_device_array__`` method) + can be passed as well. + type : pyarrow.DataType + Explicit type to attempt to coerce to, otherwise will be inferred from + the data. + mask : array[bool], optional + Indicate which values are null (True) or not null (False). + size : int64, optional + Size of the elements. If the input is larger than size bail at this + length. For iterators, if size is larger than the input iterator this + will be treated as a "max size", but will involve an initial allocation + of size followed by a resize to the actual size (so if you know the + exact size specifying it correctly will give you better performance). + from_pandas : bool, default None + Use pandas's semantics for inferring nulls from values in + ndarray-like data. If passed, the mask tasks precedence, but + if a value is unmasked (not-null), but still null according to + pandas semantics, then it is null. Defaults to False if not + passed explicitly by user, or True if a pandas object is + passed in. + safe : bool, default True + Check for overflows or other unsafe conversions. + memory_pool : pyarrow.MemoryPool, optional + If not passed, will allocate memory from the currently-set default + memory pool. + + Returns + ------- + array : pyarrow.Array or pyarrow.ChunkedArray + A ChunkedArray instead of an Array is returned if: + + - the object data overflowed binary storage. + - the object's ``__arrow_array__`` protocol method returned a chunked + array. + + Notes + ----- + Timezone will be preserved in the returned array for timezone-aware data, + else no timezone will be returned for naive timestamps. + Internally, UTC values are stored for timezone-aware data with the + timezone set in the data type. + + Pandas's DateOffsets and dateutil.relativedelta.relativedelta are by + default converted as MonthDayNanoIntervalArray. relativedelta leapdays + are ignored as are all absolute fields on both objects. datetime.timedelta + can also be converted to MonthDayNanoIntervalArray but this requires + passing MonthDayNanoIntervalType explicitly. + + Converting to dictionary array will promote to a wider integer type for + indices if the number of distinct values cannot be represented, even if + the index type was explicitly set. This means that if there are more than + 127 values the returned dictionary array's index type will be at least + pa.int16() even if pa.int8() was passed to the function. Note that an + explicit index type will not be demoted even if it is wider than required. + + Examples + -------- + >>> import pandas as pd + >>> import pyarrow as pa + >>> pa.array(pd.Series([1, 2])) + + [ + 1, + 2 + ] + + >>> pa.array(["a", "b", "a"], type=pa.dictionary(pa.int8(), pa.string())) + + ... 
+ -- dictionary: + [ + "a", + "b" + ] + -- indices: + [ + 0, + 1, + 0 + ] + + >>> import numpy as np + >>> pa.array(pd.Series([1, 2]), mask=np.array([0, 1], dtype=bool)) + + [ + 1, + null + ] + + >>> arr = pa.array(range(1024), type=pa.dictionary(pa.int8(), pa.int64())) + >>> arr.type.index_type + DataType(int16) + """ + cdef: + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + bint is_pandas_object = False + bint c_from_pandas + + type = ensure_type(type, allow_none=True) + + extension_type = None + if type is not None and type.id == _Type_EXTENSION: + extension_type = type + type = type.storage_type + + if from_pandas is None: + c_from_pandas = False + else: + c_from_pandas = from_pandas + + if isinstance(obj, Array): + if type is not None and not obj.type.equals(type): + obj = obj.cast(type, safe=safe, memory_pool=memory_pool) + return obj + + if hasattr(obj, '__arrow_array__'): + return _handle_arrow_array_protocol(obj, type, mask, size) + elif hasattr(obj, '__arrow_c_device_array__'): + if type is not None: + requested_type = type.__arrow_c_schema__() + else: + requested_type = None + schema_capsule, array_capsule = obj.__arrow_c_device_array__(requested_type) + out_array = Array._import_from_c_device_capsule(schema_capsule, array_capsule) + if type is not None and out_array.type != type: + # PyCapsule interface type coercion is best effort, so we need to + # check the type of the returned array and cast if necessary + out_array = array.cast(type, safe=safe, memory_pool=memory_pool) + return out_array + elif hasattr(obj, '__arrow_c_array__'): + if type is not None: + requested_type = type.__arrow_c_schema__() + else: + requested_type = None + schema_capsule, array_capsule = obj.__arrow_c_array__(requested_type) + out_array = Array._import_from_c_capsule(schema_capsule, array_capsule) + if type is not None and out_array.type != type: + # PyCapsule interface type coercion is best effort, so we need to + # check the type of the returned array and cast if necessary + out_array = array.cast(type, safe=safe, memory_pool=memory_pool) + return out_array + elif _is_array_like(obj): + if mask is not None: + if _is_array_like(mask): + mask = get_values(mask, &is_pandas_object) + else: + raise TypeError("Mask must be a numpy array " + "when converting numpy arrays") + + values = get_values(obj, &is_pandas_object) + if is_pandas_object and from_pandas is None: + c_from_pandas = True + + if isinstance(values, np.ma.MaskedArray): + if mask is not None: + raise ValueError("Cannot pass a numpy masked array and " + "specify a mask at the same time") + else: + # don't use shrunken masks + mask = None if values.mask is np.ma.nomask else values.mask + values = values.data + + if mask is not None: + if mask.dtype != np.bool_: + raise TypeError("Mask must be boolean dtype") + if mask.ndim != 1: + raise ValueError("Mask must be 1D array") + if len(values) != len(mask): + raise ValueError( + "Mask is a different length from sequence being converted") + + if hasattr(values, '__arrow_array__'): + return _handle_arrow_array_protocol(values, type, mask, size) + elif (pandas_api.is_categorical(values) and + type is not None and type.id != Type_DICTIONARY): + result = _ndarray_to_array( + np.asarray(values), mask, type, c_from_pandas, safe, pool + ) + elif pandas_api.is_categorical(values): + if type is not None: + index_type = type.index_type + value_type = type.value_type + if values.ordered != type.ordered: + raise ValueError( + "The 'ordered' flag of the passed categorical values " + "does not match 
the 'ordered' of the specified type. ") + else: + index_type = None + value_type = None + + indices = _codes_to_indices( + values.codes, mask, index_type, memory_pool) + try: + dictionary = array( + values.categories.values, type=value_type, + memory_pool=memory_pool) + except TypeError: + # TODO when removing the deprecation warning, this whole + # try/except can be removed (to bubble the TypeError of + # the first array(..) call) + if value_type is not None: + warnings.warn( + "The dtype of the 'categories' of the passed " + "categorical values ({0}) does not match the " + "specified type ({1}). For now ignoring the specified " + "type, but in the future this mismatch will raise a " + "TypeError".format( + values.categories.dtype, value_type), + FutureWarning, stacklevel=2) + dictionary = array( + values.categories.values, memory_pool=memory_pool) + else: + raise + + return DictionaryArray.from_arrays( + indices, dictionary, ordered=values.ordered, safe=safe) + else: + if pandas_api.have_pandas: + values, type = pandas_api.compat.get_datetimetz_type( + values, obj.dtype, type) + if type and type.id == _Type_RUN_END_ENCODED: + arr = _ndarray_to_array( + values, mask, type.value_type, c_from_pandas, safe, pool) + result = _pc().run_end_encode(arr, run_end_type=type.run_end_type, + memory_pool=memory_pool) + else: + result = _ndarray_to_array(values, mask, type, c_from_pandas, safe, + pool) + else: + if type and type.id == _Type_RUN_END_ENCODED: + arr = _sequence_to_array( + obj, mask, size, type.value_type, pool, from_pandas) + result = _pc().run_end_encode(arr, run_end_type=type.run_end_type, + memory_pool=memory_pool) + # ConvertPySequence does strict conversion if type is explicitly passed + else: + result = _sequence_to_array(obj, mask, size, type, pool, c_from_pandas) + + if extension_type is not None: + result = ExtensionArray.from_storage(extension_type, result) + return result + + +def asarray(values, type=None): + """ + Convert to pyarrow.Array, inferring type if not provided. + + Parameters + ---------- + values : array-like + This can be a sequence, numpy.ndarray, pyarrow.Array or + pyarrow.ChunkedArray. If a ChunkedArray is passed, the output will be + a ChunkedArray, otherwise the output will be a Array. + type : string or DataType + Explicitly construct the array with this type. Attempt to cast if + indicated type is different. + + Returns + ------- + arr : Array or ChunkedArray + """ + if isinstance(values, (Array, ChunkedArray)): + if type is not None and not values.type.equals(type): + values = values.cast(type) + return values + else: + return array(values, type=type) + + +def nulls(size, type=None, MemoryPool memory_pool=None): + """ + Create a strongly-typed Array instance with all elements null. + + Parameters + ---------- + size : int + Array length. + type : pyarrow.DataType, default None + Explicit type for the array. By default use NullType. + memory_pool : MemoryPool, default None + Arrow MemoryPool to use for allocations. Uses the default memory + pool if not passed. 
+ + Returns + ------- + arr : Array + + Examples + -------- + >>> import pyarrow as pa + >>> pa.nulls(10) + + 10 nulls + + >>> pa.nulls(3, pa.uint32()) + + [ + null, + null, + null + ] + """ + cdef: + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + int64_t length = size + shared_ptr[CDataType] ty + shared_ptr[CArray] arr + + type = ensure_type(type, allow_none=True) + if type is None: + type = null() + + ty = pyarrow_unwrap_data_type(type) + with nogil: + arr = GetResultValue(MakeArrayOfNull(ty, length, pool)) + + return pyarrow_wrap_array(arr) + + +def repeat(value, size, MemoryPool memory_pool=None): + """ + Create an Array instance whose slots are the given scalar. + + Parameters + ---------- + value : Scalar-like object + Either a pyarrow.Scalar or any python object coercible to a Scalar. + size : int + Number of times to repeat the scalar in the output Array. + memory_pool : MemoryPool, default None + Arrow MemoryPool to use for allocations. Uses the default memory + pool if not passed. + + Returns + ------- + arr : Array + + Examples + -------- + >>> import pyarrow as pa + >>> pa.repeat(10, 3) + + [ + 10, + 10, + 10 + ] + + >>> pa.repeat([1, 2], 2) + + [ + [ + 1, + 2 + ], + [ + 1, + 2 + ] + ] + + >>> pa.repeat("string", 3) + + [ + "string", + "string", + "string" + ] + + >>> pa.repeat(pa.scalar({'a': 1, 'b': [1, 2]}), 2) + + -- is_valid: all not null + -- child 0 type: int64 + [ + 1, + 1 + ] + -- child 1 type: list + [ + [ + 1, + 2 + ], + [ + 1, + 2 + ] + ] + """ + cdef: + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + int64_t length = size + shared_ptr[CArray] c_array + shared_ptr[CScalar] c_scalar + + if not isinstance(value, Scalar): + value = scalar(value, memory_pool=memory_pool) + + c_scalar = ( value).unwrap() + with nogil: + c_array = GetResultValue( + MakeArrayFromScalar(deref(c_scalar), length, pool) + ) + + return pyarrow_wrap_array(c_array) + + +def infer_type(values, mask=None, from_pandas=False): + """ + Attempt to infer Arrow data type that can hold the passed Python + sequence type in an Array object + + Parameters + ---------- + values : array-like + Sequence to infer type from. + mask : ndarray (bool type), optional + Optional exclusion mask where True marks null, False non-null. + from_pandas : bool, default False + Use pandas's NA/null sentinel values for type inference. 
+ + Returns + ------- + type : DataType + """ + cdef: + shared_ptr[CDataType] out + c_bool use_pandas_sentinels = from_pandas + + if mask is not None and not isinstance(mask, np.ndarray): + mask = np.array(mask, dtype=bool) + + out = GetResultValue(InferArrowType(values, mask, use_pandas_sentinels)) + return pyarrow_wrap_data_type(out) + + +def _normalize_slice(object arrow_obj, slice key): + """ + Slices with step not equal to 1 (or None) will produce a copy + rather than a zero-copy view + """ + cdef: + Py_ssize_t start, stop, step + Py_ssize_t n = len(arrow_obj) + + start, stop, step = key.indices(n) + + if step != 1: + indices = np.arange(start, stop, step) + return arrow_obj.take(indices) + else: + length = max(stop - start, 0) + return arrow_obj.slice(start, length) + + +cdef Py_ssize_t _normalize_index(Py_ssize_t index, + Py_ssize_t length) except -1: + if index < 0: + index += length + if index < 0: + raise IndexError("index out of bounds") + elif index >= length: + raise IndexError("index out of bounds") + return index + + +cdef wrap_datum(const CDatum& datum): + if datum.kind() == DatumType_ARRAY: + return pyarrow_wrap_array(MakeArray(datum.array())) + elif datum.kind() == DatumType_CHUNKED_ARRAY: + return pyarrow_wrap_chunked_array(datum.chunked_array()) + elif datum.kind() == DatumType_RECORD_BATCH: + return pyarrow_wrap_batch(datum.record_batch()) + elif datum.kind() == DatumType_TABLE: + return pyarrow_wrap_table(datum.table()) + elif datum.kind() == DatumType_SCALAR: + return pyarrow_wrap_scalar(datum.scalar()) + else: + raise ValueError("Unable to wrap Datum in a Python object") + + +cdef _append_array_buffers(const CArrayData* ad, list res): + """ + Recursively append Buffer wrappers from *ad* and its children. + """ + cdef size_t i, n + assert ad != NULL + n = ad.buffers.size() + for i in range(n): + buf = ad.buffers[i] + res.append(pyarrow_wrap_buffer(buf) + if buf.get() != NULL else None) + n = ad.child_data.size() + for i in range(n): + _append_array_buffers(ad.child_data[i].get(), res) + + +cdef _reduce_array_data(const CArrayData* ad): + """ + Recursively dissect ArrayData to (pickable) tuples. + """ + cdef size_t i, n + assert ad != NULL + + n = ad.buffers.size() + buffers = [] + for i in range(n): + buf = ad.buffers[i] + buffers.append(pyarrow_wrap_buffer(buf) + if buf.get() != NULL else None) + + children = [] + n = ad.child_data.size() + for i in range(n): + children.append(_reduce_array_data(ad.child_data[i].get())) + + if ad.dictionary.get() != NULL: + dictionary = _reduce_array_data(ad.dictionary.get()) + else: + dictionary = None + + return pyarrow_wrap_data_type(ad.type), ad.length, ad.null_count, \ + ad.offset, buffers, children, dictionary + + +cdef shared_ptr[CArrayData] _reconstruct_array_data(data): + """ + Reconstruct CArrayData objects from the tuple structure generated + by _reduce_array_data. 
+ """ + cdef: + int64_t length, null_count, offset, i + DataType dtype + Buffer buf + vector[shared_ptr[CBuffer]] c_buffers + vector[shared_ptr[CArrayData]] c_children + shared_ptr[CArrayData] c_dictionary + + dtype, length, null_count, offset, buffers, children, dictionary = data + + for i in range(len(buffers)): + buf = buffers[i] + if buf is None: + c_buffers.push_back(shared_ptr[CBuffer]()) + else: + c_buffers.push_back(buf.buffer) + + for i in range(len(children)): + c_children.push_back(_reconstruct_array_data(children[i])) + + if dictionary is not None: + c_dictionary = _reconstruct_array_data(dictionary) + + return CArrayData.MakeWithChildrenAndDictionary( + dtype.sp_type, + length, + c_buffers, + c_children, + c_dictionary, + null_count, + offset) + + +def _restore_array(data): + """ + Reconstruct an Array from pickled ArrayData. + """ + cdef shared_ptr[CArrayData] ad = _reconstruct_array_data(data) + return pyarrow_wrap_array(MakeArray(ad)) + + +cdef class _PandasConvertible(_Weakrefable): + + def to_pandas( + self, + memory_pool=None, + categories=None, + bint strings_to_categorical=False, + bint zero_copy_only=False, + bint integer_object_nulls=False, + bint date_as_object=True, + bint timestamp_as_object=False, + bint use_threads=True, + bint deduplicate_objects=True, + bint ignore_metadata=False, + bint safe=True, + bint split_blocks=False, + bint self_destruct=False, + str maps_as_pydicts=None, + types_mapper=None, + bint coerce_temporal_nanoseconds=False + ): + """ + Convert to a pandas-compatible NumPy array or DataFrame, as appropriate + + Parameters + ---------- + memory_pool : MemoryPool, default None + Arrow MemoryPool to use for allocations. Uses the default memory + pool if not passed. + categories : list, default empty + List of fields that should be returned as pandas.Categorical. Only + applies to table-like data structures. + strings_to_categorical : bool, default False + Encode string (UTF8) and binary types to pandas.Categorical. + zero_copy_only : bool, default False + Raise an ArrowException if this function call would require copying + the underlying data. + integer_object_nulls : bool, default False + Cast integers with nulls to objects + date_as_object : bool, default True + Cast dates to objects. If False, convert to datetime64 dtype with + the equivalent time unit (if supported). Note: in pandas version + < 2.0, only datetime64[ns] conversion is supported. + timestamp_as_object : bool, default False + Cast non-nanosecond timestamps (np.datetime64) to objects. This is + useful in pandas version 1.x if you have timestamps that don't fit + in the normal date range of nanosecond timestamps (1678 CE-2262 CE). + Non-nanosecond timestamps are supported in pandas version 2.0. + If False, all timestamps are converted to datetime64 dtype. + use_threads : bool, default True + Whether to parallelize the conversion using multiple threads. + deduplicate_objects : bool, default True + Do not create multiple copies Python objects when created, to save + on memory use. Conversion will be slower. + ignore_metadata : bool, default False + If True, do not use the 'pandas' metadata to reconstruct the + DataFrame index, if present + safe : bool, default True + For certain data types, a cast is needed in order to store the + data in a pandas DataFrame or Series (e.g. timestamps are always + stored as nanoseconds in pandas). This option controls whether it + is a safe cast or not. 
+ split_blocks : bool, default False + If True, generate one internal "block" for each column when + creating a pandas.DataFrame from a RecordBatch or Table. While this + can temporarily reduce memory note that various pandas operations + can trigger "consolidation" which may balloon memory use. + self_destruct : bool, default False + EXPERIMENTAL: If True, attempt to deallocate the originating Arrow + memory while converting the Arrow object to pandas. If you use the + object after calling to_pandas with this option it will crash your + program. + + Note that you may not see always memory usage improvements. For + example, if multiple columns share an underlying allocation, + memory can't be freed until all columns are converted. + maps_as_pydicts : str, optional, default `None` + Valid values are `None`, 'lossy', or 'strict'. + The default behavior (`None`), is to convert Arrow Map arrays to + Python association lists (list-of-tuples) in the same order as the + Arrow Map, as in [(key1, value1), (key2, value2), ...]. + + If 'lossy' or 'strict', convert Arrow Map arrays to native Python dicts. + This can change the ordering of (key, value) pairs, and will + deduplicate multiple keys, resulting in a possible loss of data. + + If 'lossy', this key deduplication results in a warning printed + when detected. If 'strict', this instead results in an exception + being raised when detected. + types_mapper : function, default None + A function mapping a pyarrow DataType to a pandas ExtensionDtype. + This can be used to override the default pandas type for conversion + of built-in pyarrow types or in absence of pandas_metadata in the + Table schema. The function receives a pyarrow DataType and is + expected to return a pandas ExtensionDtype or ``None`` if the + default conversion should be used for that type. If you have + a dictionary mapping, you can pass ``dict.get`` as function. + coerce_temporal_nanoseconds : bool, default False + Only applicable to pandas version >= 2.0. + A legacy option to coerce date32, date64, duration, and timestamp + time units to nanoseconds when converting to pandas. This is the + default behavior in pandas version 1.x. Set this option to True if + you'd like to use this coercion when using pandas version >= 2.0 + for backwards compatibility (not recommended otherwise). + + Returns + ------- + pandas.Series or pandas.DataFrame depending on type of object + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + + Convert a Table to pandas DataFrame: + + >>> table = pa.table([ + ... pa.array([2, 4, 5, 100]), + ... pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"]) + ... ], names=['n_legs', 'animals']) + >>> table.to_pandas() + n_legs animals + 0 2 Flamingo + 1 4 Horse + 2 5 Brittle stars + 3 100 Centipede + >>> isinstance(table.to_pandas(), pd.DataFrame) + True + + Convert a RecordBatch to pandas DataFrame: + + >>> import pyarrow as pa + >>> n_legs = pa.array([2, 4, 5, 100]) + >>> animals = pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"]) + >>> batch = pa.record_batch([n_legs, animals], + ... 
names=["n_legs", "animals"]) + >>> batch + pyarrow.RecordBatch + n_legs: int64 + animals: string + ---- + n_legs: [2,4,5,100] + animals: ["Flamingo","Horse","Brittle stars","Centipede"] + >>> batch.to_pandas() + n_legs animals + 0 2 Flamingo + 1 4 Horse + 2 5 Brittle stars + 3 100 Centipede + >>> isinstance(batch.to_pandas(), pd.DataFrame) + True + + Convert a Chunked Array to pandas Series: + + >>> import pyarrow as pa + >>> n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + >>> n_legs.to_pandas() + 0 2 + 1 2 + 2 4 + 3 4 + 4 5 + 5 100 + dtype: int64 + >>> isinstance(n_legs.to_pandas(), pd.Series) + True + """ + options = dict( + pool=memory_pool, + strings_to_categorical=strings_to_categorical, + zero_copy_only=zero_copy_only, + integer_object_nulls=integer_object_nulls, + date_as_object=date_as_object, + timestamp_as_object=timestamp_as_object, + use_threads=use_threads, + deduplicate_objects=deduplicate_objects, + safe=safe, + split_blocks=split_blocks, + self_destruct=self_destruct, + maps_as_pydicts=maps_as_pydicts, + coerce_temporal_nanoseconds=coerce_temporal_nanoseconds + ) + return self._to_pandas(options, categories=categories, + ignore_metadata=ignore_metadata, + types_mapper=types_mapper) + + +cdef PandasOptions _convert_pandas_options(dict options): + cdef PandasOptions result + result.pool = maybe_unbox_memory_pool(options['pool']) + result.strings_to_categorical = options['strings_to_categorical'] + result.zero_copy_only = options['zero_copy_only'] + result.integer_object_nulls = options['integer_object_nulls'] + result.date_as_object = options['date_as_object'] + result.timestamp_as_object = options['timestamp_as_object'] + result.use_threads = options['use_threads'] + result.deduplicate_objects = options['deduplicate_objects'] + result.safe_cast = options['safe'] + result.split_blocks = options['split_blocks'] + result.self_destruct = options['self_destruct'] + result.coerce_temporal_nanoseconds = options['coerce_temporal_nanoseconds'] + result.ignore_timezone = os.environ.get('PYARROW_IGNORE_TIMEZONE', False) + + maps_as_pydicts = options['maps_as_pydicts'] + if maps_as_pydicts is None: + result.maps_as_pydicts = MapConversionType.DEFAULT + elif maps_as_pydicts == "lossy": + result.maps_as_pydicts = MapConversionType.LOSSY + elif maps_as_pydicts == "strict": + result.maps_as_pydicts = MapConversionType.STRICT_ + else: + raise ValueError( + "Invalid value for 'maps_as_pydicts': " + + "valid values are 'lossy', 'strict' or `None` (default). " + + f"Received '{maps_as_pydicts}'." + ) + return result + + +cdef class Array(_PandasConvertible): + """ + The base class for all Arrow arrays. + """ + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use one of " + "the `pyarrow.Array.from_*` functions instead." + .format(self.__class__.__name__)) + + cdef void init(self, const shared_ptr[CArray]& sp_array) except *: + self.sp_array = sp_array + self.ap = sp_array.get() + self.type = pyarrow_wrap_data_type(self.sp_array.get().type()) + + def _debug_print(self): + with nogil: + check_status(DebugPrint(deref(self.ap), 0)) + + def diff(self, Array other): + """ + Compare contents of this array against another one. + + Return a string containing the result of diffing this array + (on the left side) against the other array (on the right side). + + Parameters + ---------- + other : Array + The other array to compare this array with. + + Returns + ------- + diff : str + A human-readable printout of the differences. 
+ + Examples + -------- + >>> import pyarrow as pa + >>> left = pa.array(["one", "two", "three"]) + >>> right = pa.array(["two", None, "two-and-a-half", "three"]) + >>> print(left.diff(right)) # doctest: +SKIP + + @@ -0, +0 @@ + -"one" + @@ -2, +1 @@ + +null + +"two-and-a-half" + + """ + self._assert_cpu() + cdef c_string result + with nogil: + result = self.ap.Diff(deref(other.ap)) + return frombytes(result, safe=True) + + def cast(self, object target_type=None, safe=None, options=None, memory_pool=None): + """ + Cast array values to another data type + + See :func:`pyarrow.compute.cast` for usage. + + Parameters + ---------- + target_type : DataType, default None + Type to cast array to. + safe : boolean, default True + Whether to check for conversion errors such as overflow. + options : CastOptions, default None + Additional checks pass by CastOptions + memory_pool : MemoryPool, optional + memory pool to use for allocations during function execution. + + Returns + ------- + cast : Array + """ + self._assert_cpu() + return _pc().cast(self, target_type, safe=safe, + options=options, memory_pool=memory_pool) + + def view(self, object target_type): + """ + Return zero-copy "view" of array as another data type. + + The data types must have compatible columnar buffer layouts + + Parameters + ---------- + target_type : DataType + Type to construct view as. + + Returns + ------- + view : Array + """ + self._assert_cpu() + cdef DataType type = ensure_type(target_type) + cdef shared_ptr[CArray] result + with nogil: + result = GetResultValue(self.ap.View(type.sp_type)) + return pyarrow_wrap_array(result) + + def sum(self, **kwargs): + """ + Sum the values in a numerical array. + + See :func:`pyarrow.compute.sum` for full usage. + + Parameters + ---------- + **kwargs : dict, optional + Options to pass to :func:`pyarrow.compute.sum`. + + Returns + ------- + sum : Scalar + A scalar containing the sum value. + """ + self._assert_cpu() + options = _pc().ScalarAggregateOptions(**kwargs) + return _pc().call_function('sum', [self], options) + + def unique(self): + """ + Compute distinct elements in array. + + Returns + ------- + unique : Array + An array of the same data type, with deduplicated elements. + """ + self._assert_cpu() + return _pc().call_function('unique', [self]) + + def dictionary_encode(self, null_encoding='mask'): + """ + Compute dictionary-encoded representation of array. + + See :func:`pyarrow.compute.dictionary_encode` for full usage. + + Parameters + ---------- + null_encoding : str, default "mask" + How to handle null entries. + + Returns + ------- + encoded : DictionaryArray + A dictionary-encoded version of this array. + """ + self._assert_cpu() + options = _pc().DictionaryEncodeOptions(null_encoding) + return _pc().call_function('dictionary_encode', [self], options) + + def value_counts(self): + """ + Compute counts of unique elements in array. + + Returns + ------- + StructArray + An array of structs + """ + self._assert_cpu() + return _pc().call_function('value_counts', [self]) + + @staticmethod + def from_pandas(obj, mask=None, type=None, bint safe=True, + MemoryPool memory_pool=None): + """ + Convert pandas.Series to an Arrow Array. + + This method uses Pandas semantics about what values indicate + nulls. See pyarrow.array for more general conversion from arrays or + sequences to Arrow arrays. + + Parameters + ---------- + obj : ndarray, pandas.Series, array-like + mask : array (boolean), optional + Indicate which values are null (True) or not null (False). 
+ type : pyarrow.DataType + Explicit type to attempt to coerce to, otherwise will be inferred + from the data. + safe : bool, default True + Check for overflows or other unsafe conversions. + memory_pool : pyarrow.MemoryPool, optional + If not passed, will allocate memory from the currently-set default + memory pool. + + Notes + ----- + Localized timestamps will currently be returned as UTC (pandas's native + representation). Timezone-naive data will be implicitly interpreted as + UTC. + + Returns + ------- + array : pyarrow.Array or pyarrow.ChunkedArray + ChunkedArray is returned if object data overflows binary buffer. + """ + return array(obj, mask=mask, type=type, safe=safe, from_pandas=True, + memory_pool=memory_pool) + + def __reduce__(self): + self._assert_cpu() + return _restore_array, \ + (_reduce_array_data(self.sp_array.get().data().get()),) + + @staticmethod + def from_buffers(DataType type, length, buffers, null_count=-1, offset=0, + children=None): + """ + Construct an Array from a sequence of buffers. + + The concrete type returned depends on the datatype. + + Parameters + ---------- + type : DataType + The value type of the array. + length : int + The number of values in the array. + buffers : List[Buffer] + The buffers backing this array. + null_count : int, default -1 + The number of null entries in the array. Negative value means that + the null count is not known. + offset : int, default 0 + The array's logical offset (in values, not in bytes) from the + start of each buffer. + children : List[Array], default None + Nested type children with length matching type.num_fields. + + Returns + ------- + array : Array + """ + cdef: + Buffer buf + Array child + vector[shared_ptr[CBuffer]] c_buffers + vector[shared_ptr[CArrayData]] c_child_data + shared_ptr[CArrayData] array_data + + children = children or [] + + if type.num_fields != len(children): + raise ValueError("Type's expected number of children " + "({0}) did not match the passed number " + "({1}).".format(type.num_fields, len(children))) + + if type.has_variadic_buffers: + if type.num_buffers > len(buffers): + raise ValueError("Type's expected number of buffers is at least " + "{0}, but the passed number is " + "{1}.".format(type.num_buffers, len(buffers))) + elif type.num_buffers != len(buffers): + raise ValueError("Type's expected number of buffers " + "({0}) did not match the passed number " + "({1}).".format(type.num_buffers, len(buffers))) + + for buf in buffers: + # None will produce a null buffer pointer + c_buffers.push_back(pyarrow_unwrap_buffer(buf)) + + for child in children: + c_child_data.push_back(child.ap.data()) + + array_data = CArrayData.MakeWithChildren(type.sp_type, length, + c_buffers, c_child_data, + null_count, offset) + cdef Array result = pyarrow_wrap_array(MakeArray(array_data)) + result.validate() + return result + + @property + def null_count(self): + self._assert_cpu() + return self.sp_array.get().null_count() + + @property + def nbytes(self): + """ + Total number of bytes consumed by the elements of the array. + + In other words, the sum of bytes from all buffer + ranges referenced. + + Unlike `get_total_buffer_size` this method will account for array + offsets. + + If buffers are shared between arrays then the shared + portion will be counted multiple times. + + The dictionary of dictionary arrays will always be counted in their + entirety even if the array only references a portion of the dictionary. 
+ """ + self._assert_cpu() + cdef CResult[int64_t] c_size_res + with nogil: + c_size_res = ReferencedBufferSize(deref(self.ap)) + size = GetResultValue(c_size_res) + return size + + def get_total_buffer_size(self): + """ + The sum of bytes in each buffer referenced by the array. + + An array may only reference a portion of a buffer. + This method will overestimate in this case and return the + byte size of the entire buffer. + + If a buffer is referenced multiple times then it will + only be counted once. + """ + self._assert_cpu() + cdef int64_t total_buffer_size + total_buffer_size = TotalBufferSize(deref(self.ap)) + return total_buffer_size + + def __sizeof__(self): + self._assert_cpu() + return super(Array, self).__sizeof__() + self.nbytes + + def __iter__(self): + self._assert_cpu() + for i in range(len(self)): + yield self.getitem(i) + + def __repr__(self): + type_format = object.__repr__(self) + return '{0}\n{1}'.format(type_format, str(self)) + + def to_string(self, *, int indent=2, int top_level_indent=0, int window=10, + int container_window=2, c_bool skip_new_lines=False): + """ + Render a "pretty-printed" string representation of the Array. + + Note: for data on a non-CPU device, the full array is copied to CPU + memory. + + Parameters + ---------- + indent : int, default 2 + How much to indent the internal items in the string to + the right, by default ``2``. + top_level_indent : int, default 0 + How much to indent right the entire content of the array, + by default ``0``. + window : int + How many primitive items to preview at the begin and end + of the array when the array is bigger than the window. + The other items will be ellipsed. + container_window : int + How many container items (such as a list in a list array) + to preview at the begin and end of the array when the array + is bigger than the window. + skip_new_lines : bool + If the array should be rendered as a single line of text + or if each element should be on its own line. + """ + cdef: + c_string result + PrettyPrintOptions options + + with nogil: + options = PrettyPrintOptions(top_level_indent, window) + options.skip_new_lines = skip_new_lines + options.indent_size = indent + check_status( + PrettyPrint( + deref(self.ap), + options, + &result + ) + ) + + return frombytes(result, safe=True) + + def format(self, **kwargs): + """ + DEPRECATED, use pyarrow.Array.to_string + + Parameters + ---------- + **kwargs : dict + + Returns + ------- + str + """ + import warnings + warnings.warn('Array.format is deprecated, use Array.to_string') + return self.to_string(**kwargs) + + def __str__(self): + return self.to_string() + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + # This also handles comparing with None + # as Array.equals(None) raises a TypeError. + return NotImplemented + + def equals(Array self, Array other not None): + """ + Parameters + ---------- + other : pyarrow.Array + + Returns + ------- + bool + """ + self._assert_cpu() + other._assert_cpu() + return self.ap.Equals(deref(other.ap)) + + def __len__(self): + return self.length() + + cdef int64_t length(self): + if self.sp_array.get(): + return self.sp_array.get().length() + else: + return 0 + + def is_null(self, *, nan_is_null=False): + """ + Return BooleanArray indicating the null values. + + Parameters + ---------- + nan_is_null : bool (optional, default False) + Whether floating-point NaN values should also be considered null. 
+ + Returns + ------- + array : boolean Array + """ + self._assert_cpu() + options = _pc().NullOptions(nan_is_null=nan_is_null) + return _pc().call_function('is_null', [self], options) + + def is_nan(self): + """ + Return BooleanArray indicating the NaN values. + + Returns + ------- + array : boolean Array + """ + self._assert_cpu() + return _pc().call_function('is_nan', [self]) + + def is_valid(self): + """ + Return BooleanArray indicating the non-null values. + """ + self._assert_cpu() + return _pc().is_valid(self) + + def fill_null(self, fill_value): + """ + See :func:`pyarrow.compute.fill_null` for usage. + + Parameters + ---------- + fill_value : any + The replacement value for null entries. + + Returns + ------- + result : Array + A new array with nulls replaced by the given value. + """ + self._assert_cpu() + return _pc().fill_null(self, fill_value) + + def __getitem__(self, key): + """ + Slice or return value at given index + + Parameters + ---------- + key : integer or slice + Slices with step not equal to 1 (or None) will produce a copy + rather than a zero-copy view + + Returns + ------- + value : Scalar (index) or Array (slice) + """ + self._assert_cpu() + if isinstance(key, slice): + return _normalize_slice(self, key) + + return self.getitem(_normalize_index(key, self.length())) + + cdef getitem(self, int64_t i): + self._assert_cpu() + return Scalar.wrap(GetResultValue(self.ap.GetScalar(i))) + + def slice(self, offset=0, length=None): + """ + Compute zero-copy slice of this array. + + Parameters + ---------- + offset : int, default 0 + Offset from start of array to slice. + length : int, default None + Length of slice (default is until end of Array starting from + offset). + + Returns + ------- + sliced : Array + An array with the same datatype, containing the sliced values. + """ + cdef shared_ptr[CArray] result + + if offset < 0: + raise IndexError('Offset must be non-negative') + + offset = min(len(self), offset) + if length is None: + result = self.ap.Slice(offset) + else: + if length < 0: + raise ValueError('Length must be non-negative') + result = self.ap.Slice(offset, length) + + return pyarrow_wrap_array(result) + + def take(self, object indices): + """ + Select values from an array. + + See :func:`pyarrow.compute.take` for full usage. + + Parameters + ---------- + indices : Array or array-like + The indices in the array whose values will be returned. + + Returns + ------- + taken : Array + An array with the same datatype, containing the taken values. + """ + self._assert_cpu() + return _pc().take(self, indices) + + def drop_null(self): + """ + Remove missing values from an array. + """ + self._assert_cpu() + return _pc().drop_null(self) + + def filter(self, object mask, *, null_selection_behavior='drop'): + """ + Select values from an array. + + See :func:`pyarrow.compute.filter` for full usage. + + Parameters + ---------- + mask : Array or array-like + The boolean mask to filter the array with. + null_selection_behavior : str, default "drop" + How nulls in the mask should be handled. + + Returns + ------- + filtered : Array + An array of the same type, with only the elements selected by + the boolean mask. + """ + self._assert_cpu() + return _pc().filter(self, mask, + null_selection_behavior=null_selection_behavior) + + def index(self, value, start=None, end=None, *, memory_pool=None): + """ + Find the first index of a value. + + See :func:`pyarrow.compute.index` for full usage. 
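+
+        As a quick, illustrative sketch of the expected behaviour::
+
+            import pyarrow as pa
+            arr = pa.array(["a", "b", "a"])
+            arr.index("a").as_py()            # -> 0
+            arr.index("a", start=1).as_py()   # -> 2
+            arr.index("z").as_py()            # -> -1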
+ + Parameters + ---------- + value : Scalar or object + The value to look for in the array. + start : int, optional + The start index where to look for `value`. + end : int, optional + The end index where to look for `value`. + memory_pool : MemoryPool, optional + A memory pool for potential memory allocations. + + Returns + ------- + index : Int64Scalar + The index of the value in the array (-1 if not found). + """ + self._assert_cpu() + return _pc().index(self, value, start, end, memory_pool=memory_pool) + + def sort(self, order="ascending", **kwargs): + """ + Sort the Array + + Parameters + ---------- + order : str, default "ascending" + Which order to sort values in. + Accepted values are "ascending", "descending". + **kwargs : dict, optional + Additional sorting options. + As allowed by :class:`SortOptions` + + Returns + ------- + result : Array + """ + self._assert_cpu() + indices = _pc().sort_indices( + self, + options=_pc().SortOptions(sort_keys=[("", order)], **kwargs) + ) + return self.take(indices) + + def _to_pandas(self, options, types_mapper=None, **kwargs): + self._assert_cpu() + return _array_like_to_pandas(self, options, types_mapper=types_mapper) + + def __array__(self, dtype=None, copy=None): + self._assert_cpu() + + if copy is False: + try: + values = self.to_numpy(zero_copy_only=True) + except ArrowInvalid: + raise ValueError( + "Unable to avoid a copy while creating a numpy array as requested.\n" + "If using `np.array(obj, copy=False)` replace it with " + "`np.asarray(obj)` to allow a copy when needed" + ) + # values is already a numpy array at this point, but calling np.array(..) + # again to handle the `dtype` keyword with a no-copy guarantee + return np.array(values, dtype=dtype, copy=False) + + values = self.to_numpy(zero_copy_only=False) + if copy is True and is_numeric(self.type.id) and self.null_count == 0: + # to_numpy did not yet make a copy (is_numeric = integer/floats, no decimal) + return np.array(values, dtype=dtype, copy=True) + + if dtype is None: + return values + return np.asarray(values, dtype=dtype) + + def to_numpy(self, zero_copy_only=True, writable=False): + """ + Return a NumPy view or copy of this array. + + By default, tries to return a view of this array. This is only + supported for primitive arrays with the same memory layout as NumPy + (i.e. integers, floating point, ..) and without any nulls. + + For the extension arrays, this method simply delegates to the + underlying storage array. + + Parameters + ---------- + zero_copy_only : bool, default True + If True, an exception will be raised if the conversion to a numpy + array would require copying the underlying data (e.g. in presence + of nulls, or for non-primitive types). + writable : bool, default False + For numpy arrays created with zero copy (view on the Arrow data), + the resulting array is not writable (Arrow data is immutable). + By setting this to True, a copy of the array is made to ensure + it is writable. + + Returns + ------- + array : numpy.ndarray + """ + self._assert_cpu() + + if np is None: + raise ImportError( + "Cannot return a numpy.ndarray if NumPy is not present") + cdef: + PyObject* out + PandasOptions c_options + object values + + if zero_copy_only and writable: + raise ValueError( + "Cannot return a writable array if asking for zero-copy") + + # If there are nulls and the array is a DictionaryArray + # decoding the dictionary will make sure nulls are correctly handled. 
+ # Decoding a dictionary does imply a copy by the way, + # so it can't be done if the user requested a zero_copy. + c_options.decode_dictionaries = True + c_options.zero_copy_only = zero_copy_only + c_options.to_numpy = True + + with nogil: + check_status(ConvertArrayToPandas(c_options, self.sp_array, + self, &out)) + + # wrap_array_output uses pandas to convert to Categorical, here + # always convert to numpy array without pandas dependency + array = PyObject_to_object(out) + + if writable and not array.flags.writeable: + # if the conversion already needed to a copy, writeable is True + array = array.copy() + return array + + def to_pylist(self): + """ + Convert to a list of native Python objects. + + Returns + ------- + lst : list + """ + self._assert_cpu() + return [x.as_py() for x in self] + + def tolist(self): + """ + Alias of to_pylist for compatibility with NumPy. + """ + return self.to_pylist() + + def validate(self, *, full=False): + """ + Perform validation checks. An exception is raised if validation fails. + + By default only cheap validation checks are run. Pass `full=True` + for thorough validation checks (potentially O(n)). + + Parameters + ---------- + full : bool, default False + If True, run expensive checks, otherwise cheap checks only. + + Raises + ------ + ArrowInvalid + """ + if full: + self._assert_cpu() + with nogil: + check_status(self.ap.ValidateFull()) + else: + with nogil: + check_status(self.ap.Validate()) + + @property + def offset(self): + """ + A relative position into another array's data. + + The purpose is to enable zero-copy slicing. This value defaults to zero + but must be applied on all operations with the physical storage + buffers. + """ + return self.sp_array.get().offset() + + def buffers(self): + """ + Return a list of Buffer objects pointing to this array's physical + storage. + + To correctly interpret these buffers, you need to also apply the offset + multiplied with the size of the stored data type. + """ + res = [] + _append_array_buffers(self.sp_array.get().data().get(), res) + return res + + def copy_to(self, destination): + """ + Construct a copy of the array with all buffers on destination + device. + + This method recursively copies the array's buffers and those of its + children onto the destination MemoryManager device and returns the + new Array. + + Parameters + ---------- + destination : pyarrow.MemoryManager or pyarrow.Device + The destination device to copy the array to. + + Returns + ------- + Array + """ + cdef: + shared_ptr[CArray] c_array + shared_ptr[CMemoryManager] c_memory_manager + + if isinstance(destination, Device): + c_memory_manager = (destination).unwrap().get().default_memory_manager() + elif isinstance(destination, MemoryManager): + c_memory_manager = (destination).unwrap() + else: + raise TypeError( + "Argument 'destination' has incorrect type (expected a " + f"pyarrow Device or MemoryManager, got {type(destination)})" + ) + + with nogil: + c_array = GetResultValue(self.ap.CopyTo(c_memory_manager)) + return pyarrow_wrap_array(c_array) + + def _export_to_c(self, out_ptr, out_schema_ptr=0): + """ + Export to a C ArrowArray struct, given its pointer. + + If a C ArrowSchema struct pointer is also given, the array type + is exported to it at the same time. + + Parameters + ---------- + out_ptr: int + The raw pointer to a C ArrowArray struct. + out_schema_ptr: int (optional) + The raw pointer to a C ArrowSchema struct. + + Be careful: if you don't pass the ArrowArray struct to a consumer, + array memory will leak. 
This is a low-level function intended for + expert users. + """ + cdef: + void* c_ptr = _as_c_pointer(out_ptr) + void* c_schema_ptr = _as_c_pointer(out_schema_ptr, + allow_null=True) + with nogil: + check_status(ExportArray(deref(self.sp_array), + c_ptr, + c_schema_ptr)) + + @staticmethod + def _import_from_c(in_ptr, type): + """ + Import Array from a C ArrowArray struct, given its pointer + and the imported array type. + + Parameters + ---------- + in_ptr: int + The raw pointer to a C ArrowArray struct. + type: DataType or int + Either a DataType object, or the raw pointer to a C ArrowSchema + struct. + + This is a low-level function intended for expert users. + """ + cdef: + void* c_ptr = _as_c_pointer(in_ptr) + void* c_type_ptr + shared_ptr[CArray] c_array + + c_type = pyarrow_unwrap_data_type(type) + if c_type == nullptr: + # Not a DataType object, perhaps a raw ArrowSchema pointer + c_type_ptr = _as_c_pointer(type) + with nogil: + c_array = GetResultValue(ImportArray( + c_ptr, c_type_ptr)) + else: + with nogil: + c_array = GetResultValue(ImportArray( c_ptr, + c_type)) + return pyarrow_wrap_array(c_array) + + def __arrow_c_array__(self, requested_schema=None): + """ + Get a pair of PyCapsules containing a C ArrowArray representation of the object. + + Parameters + ---------- + requested_schema : PyCapsule | None + A PyCapsule containing a C ArrowSchema representation of a requested + schema. PyArrow will attempt to cast the array to this data type. + If None, the array will be returned as-is, with a type matching the + one returned by :meth:`__arrow_c_schema__()`. + + Returns + ------- + Tuple[PyCapsule, PyCapsule] + A pair of PyCapsules containing a C ArrowSchema and ArrowArray, + respectively. + """ + self._assert_cpu() + + cdef: + ArrowArray* c_array + ArrowSchema* c_schema + shared_ptr[CArray] inner_array + + if requested_schema is not None: + target_type = DataType._import_from_c_capsule(requested_schema) + + if target_type != self.type: + try: + casted_array = _pc().cast(self, target_type, safe=True) + inner_array = pyarrow_unwrap_array(casted_array) + except ArrowInvalid as e: + raise ValueError( + f"Could not cast {self.type} to requested type {target_type}: {e}" + ) + else: + inner_array = self.sp_array + else: + inner_array = self.sp_array + + schema_capsule = alloc_c_schema(&c_schema) + array_capsule = alloc_c_array(&c_array) + + with nogil: + check_status(ExportArray(deref(inner_array), c_array, c_schema)) + + return schema_capsule, array_capsule + + @staticmethod + def _import_from_c_capsule(schema_capsule, array_capsule): + cdef: + ArrowSchema* c_schema + ArrowArray* c_array + shared_ptr[CArray] array + + c_schema = PyCapsule_GetPointer(schema_capsule, 'arrow_schema') + c_array = PyCapsule_GetPointer(array_capsule, 'arrow_array') + + with nogil: + array = GetResultValue(ImportArray(c_array, c_schema)) + + return pyarrow_wrap_array(array) + + def _export_to_c_device(self, out_ptr, out_schema_ptr=0): + """ + Export to a C ArrowDeviceArray struct, given its pointer. + + If a C ArrowSchema struct pointer is also given, the array type + is exported to it at the same time. + + Parameters + ---------- + out_ptr: int + The raw pointer to a C ArrowDeviceArray struct. + out_schema_ptr: int (optional) + The raw pointer to a C ArrowSchema struct. + + Be careful: if you don't pass the ArrowDeviceArray struct to a consumer, + array memory will leak. This is a low-level function intended for + expert users. 
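+
+        As a hedged sketch of the call pattern (not a complete consumer): the
+        struct is allocated with ``cffi`` here, which is an assumption of this
+        example (``pyarrow.cffi.ffi`` is assumed to declare the Arrow C device
+        structs; any ABI-compatible allocation works)::
+
+            import pyarrow as pa
+            from pyarrow.cffi import ffi
+
+            arr = pa.array([1, 2, 3])
+            c_arr = ffi.new("struct ArrowDeviceArray*")
+            ptr = int(ffi.cast("uintptr_t", c_arr))
+            arr._export_to_c_device(ptr)
+            pa.Array._import_from_c_device(ptr, arr.type).equals(arr)  # -> True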
+ """ + cdef: + void* c_ptr = _as_c_pointer(out_ptr) + void* c_schema_ptr = _as_c_pointer(out_schema_ptr, + allow_null=True) + with nogil: + check_status(ExportDeviceArray( + deref(self.sp_array), NULL, + c_ptr, c_schema_ptr)) + + @staticmethod + def _import_from_c_device(in_ptr, type): + """ + Import Array from a C ArrowDeviceArray struct, given its pointer + and the imported array type. + + Parameters + ---------- + in_ptr: int + The raw pointer to a C ArrowDeviceArray struct. + type: DataType or int + Either a DataType object, or the raw pointer to a C ArrowSchema + struct. + + This is a low-level function intended for expert users. + """ + cdef: + ArrowDeviceArray* c_device_array = _as_c_pointer(in_ptr) + void* c_type_ptr + shared_ptr[CArray] c_array + + if c_device_array.device_type == ARROW_DEVICE_CUDA: + _ensure_cuda_loaded() + + c_type = pyarrow_unwrap_data_type(type) + if c_type == nullptr: + # Not a DataType object, perhaps a raw ArrowSchema pointer + c_type_ptr = _as_c_pointer(type) + with nogil: + c_array = GetResultValue( + ImportDeviceArray(c_device_array, c_type_ptr) + ) + else: + with nogil: + c_array = GetResultValue( + ImportDeviceArray(c_device_array, c_type) + ) + return pyarrow_wrap_array(c_array) + + def __arrow_c_device_array__(self, requested_schema=None, **kwargs): + """ + Get a pair of PyCapsules containing a C ArrowDeviceArray representation + of the object. + + Parameters + ---------- + requested_schema : PyCapsule | None + A PyCapsule containing a C ArrowSchema representation of a requested + schema. PyArrow will attempt to cast the array to this data type. + If None, the array will be returned as-is, with a type matching the + one returned by :meth:`__arrow_c_schema__()`. + kwargs + Currently no additional keyword arguments are supported, but + this method will accept any keyword with a value of ``None`` + for compatibility with future keywords. + + Returns + ------- + Tuple[PyCapsule, PyCapsule] + A pair of PyCapsules containing a C ArrowSchema and ArrowDeviceArray, + respectively. 
+ """ + cdef: + ArrowDeviceArray* c_array + ArrowSchema* c_schema + shared_ptr[CArray] inner_array + + non_default_kwargs = [ + name for name, value in kwargs.items() if value is not None + ] + if non_default_kwargs: + raise NotImplementedError( + f"Received unsupported keyword argument(s): {non_default_kwargs}" + ) + + if requested_schema is not None: + target_type = DataType._import_from_c_capsule(requested_schema) + + if target_type != self.type: + if not self.is_cpu: + raise NotImplementedError( + "Casting to a requested schema is only supported for CPU data" + ) + try: + casted_array = _pc().cast(self, target_type, safe=True) + inner_array = pyarrow_unwrap_array(casted_array) + except ArrowInvalid as e: + raise ValueError( + f"Could not cast {self.type} to requested type {target_type}: {e}" + ) + else: + inner_array = self.sp_array + else: + inner_array = self.sp_array + + schema_capsule = alloc_c_schema(&c_schema) + array_capsule = alloc_c_device_array(&c_array) + + with nogil: + check_status(ExportDeviceArray( + deref(inner_array), NULL, + c_array, c_schema)) + + return schema_capsule, array_capsule + + @staticmethod + def _import_from_c_device_capsule(schema_capsule, array_capsule): + cdef: + ArrowSchema* c_schema + ArrowDeviceArray* c_array + shared_ptr[CArray] array + + c_schema = PyCapsule_GetPointer(schema_capsule, 'arrow_schema') + c_array = PyCapsule_GetPointer( + array_capsule, 'arrow_device_array' + ) + + with nogil: + array = GetResultValue(ImportDeviceArray(c_array, c_schema)) + + return pyarrow_wrap_array(array) + + def __dlpack__(self, stream=None): + """Export a primitive array as a DLPack capsule. + + Parameters + ---------- + stream : int, optional + A Python integer representing a pointer to a stream. Currently not supported. + Stream is provided by the consumer to the producer to instruct the producer + to ensure that operations can safely be performed on the array. + + Returns + ------- + capsule : PyCapsule + A DLPack capsule for the array, pointing to a DLManagedTensor. + """ + if stream is None: + dlm_tensor = GetResultValue(ExportToDLPack(self.sp_array)) + + return PyCapsule_New(dlm_tensor, 'dltensor', dlpack_pycapsule_deleter) + else: + raise NotImplementedError( + "Only stream=None is supported." + ) + + def __dlpack_device__(self): + """ + Return the DLPack device tuple this arrays resides on. + + Returns + ------- + tuple : Tuple[int, int] + Tuple with index specifying the type of the device (where + CPU = 1, see cpp/src/arrow/c/dpack_abi.h) and index of the + device which is 0 by default for CPU. + """ + device = GetResultValue(ExportDevice(self.sp_array)) + return device.device_type, device.device_id + + @property + def device_type(self): + """ + The device type where the array resides. + + Returns + ------- + DeviceAllocationType + """ + return _wrap_device_allocation_type(self.sp_array.get().device_type()) + + @property + def is_cpu(self): + """ + Whether the array is CPU-accessible. 
+ """ + return self.device_type == DeviceAllocationType.CPU + + cdef void _assert_cpu(self) except *: + if self.sp_array.get().device_type() != CDeviceAllocationType_kCPU: + raise NotImplementedError("Implemented only for data on CPU device") + + +cdef _array_like_to_pandas(obj, options, types_mapper): + cdef: + PyObject* out + PandasOptions c_options = _convert_pandas_options(options) + + original_type = obj.type + name = obj._name + dtype = None + + if types_mapper: + dtype = types_mapper(original_type) + elif original_type.id == _Type_EXTENSION: + try: + dtype = original_type.to_pandas_dtype() + except NotImplementedError: + pass + + # Only call __from_arrow__ for Arrow extension types or when explicitly + # overridden via types_mapper + if hasattr(dtype, '__from_arrow__'): + arr = dtype.__from_arrow__(obj) + return pandas_api.series(arr, name=name, copy=False) + + if pandas_api.is_v1(): + # ARROW-3789: Coerce date/timestamp types to datetime64[ns] + c_options.coerce_temporal_nanoseconds = True + + if isinstance(obj, Array): + with nogil: + check_status(ConvertArrayToPandas(c_options, + ( obj).sp_array, + obj, &out)) + elif isinstance(obj, ChunkedArray): + with nogil: + check_status(libarrow_python.ConvertChunkedArrayToPandas( + c_options, + ( obj).sp_chunked_array, + obj, &out)) + + arr = wrap_array_output(out) + + if (isinstance(original_type, TimestampType) and + options["timestamp_as_object"]): + # ARROW-5359 - need to specify object dtype to avoid pandas to + # coerce back to ns resolution + dtype = "object" + elif types_mapper: + dtype = types_mapper(original_type) + else: + dtype = None + + result = pandas_api.series(arr, dtype=dtype, name=name, copy=False) + + if (isinstance(original_type, TimestampType) and + original_type.tz is not None and + # can be object dtype for non-ns and timestamp_as_object=True + result.dtype.kind == "M"): + from pyarrow.pandas_compat import make_tz_aware + result = make_tz_aware(result, original_type.tz) + + return result + + +cdef wrap_array_output(PyObject* output): + cdef object obj = PyObject_to_object(output) + + if isinstance(obj, dict): + return _pandas_api.categorical_type.from_codes( + obj['indices'], categories=obj['dictionary'], ordered=obj['ordered'] + ) + else: + return obj + + +cdef class NullArray(Array): + """ + Concrete class for Arrow arrays of null data type. + """ + + +cdef class BooleanArray(Array): + """ + Concrete class for Arrow arrays of boolean data type. + """ + @property + def false_count(self): + return ( self.ap).false_count() + + @property + def true_count(self): + return ( self.ap).true_count() + + +cdef class NumericArray(Array): + """ + A base class for Arrow numeric arrays. + """ + + +cdef class IntegerArray(NumericArray): + """ + A base class for Arrow integer arrays. + """ + + +cdef class FloatingPointArray(NumericArray): + """ + A base class for Arrow floating-point arrays. + """ + + +cdef class Int8Array(IntegerArray): + """ + Concrete class for Arrow arrays of int8 data type. + """ + + +cdef class UInt8Array(IntegerArray): + """ + Concrete class for Arrow arrays of uint8 data type. + """ + + +cdef class Int16Array(IntegerArray): + """ + Concrete class for Arrow arrays of int16 data type. + """ + + +cdef class UInt16Array(IntegerArray): + """ + Concrete class for Arrow arrays of uint16 data type. + """ + + +cdef class Int32Array(IntegerArray): + """ + Concrete class for Arrow arrays of int32 data type. + """ + + +cdef class UInt32Array(IntegerArray): + """ + Concrete class for Arrow arrays of uint32 data type. 
+ """ + + +cdef class Int64Array(IntegerArray): + """ + Concrete class for Arrow arrays of int64 data type. + """ + + +cdef class UInt64Array(IntegerArray): + """ + Concrete class for Arrow arrays of uint64 data type. + """ + + +cdef class Date32Array(NumericArray): + """ + Concrete class for Arrow arrays of date32 data type. + """ + + +cdef class Date64Array(NumericArray): + """ + Concrete class for Arrow arrays of date64 data type. + """ + + +cdef class TimestampArray(NumericArray): + """ + Concrete class for Arrow arrays of timestamp data type. + """ + + +cdef class Time32Array(NumericArray): + """ + Concrete class for Arrow arrays of time32 data type. + """ + + +cdef class Time64Array(NumericArray): + """ + Concrete class for Arrow arrays of time64 data type. + """ + + +cdef class DurationArray(NumericArray): + """ + Concrete class for Arrow arrays of duration data type. + """ + + +cdef class MonthDayNanoIntervalArray(Array): + """ + Concrete class for Arrow arrays of interval[MonthDayNano] type. + """ + + def to_pylist(self): + """ + Convert to a list of native Python objects. + + pyarrow.MonthDayNano is used as the native representation. + + Returns + ------- + lst : list + """ + cdef: + CResult[PyObject*] maybe_py_list + PyObject* py_list + CMonthDayNanoIntervalArray* array + array = self.sp_array.get() + maybe_py_list = MonthDayNanoIntervalArrayToPyList(deref(array)) + py_list = GetResultValue(maybe_py_list) + return PyObject_to_object(py_list) + + +cdef class HalfFloatArray(FloatingPointArray): + """ + Concrete class for Arrow arrays of float16 data type. + """ + + +cdef class FloatArray(FloatingPointArray): + """ + Concrete class for Arrow arrays of float32 data type. + """ + + +cdef class DoubleArray(FloatingPointArray): + """ + Concrete class for Arrow arrays of float64 data type. + """ + + +cdef class FixedSizeBinaryArray(Array): + """ + Concrete class for Arrow arrays of a fixed-size binary data type. + """ + +cdef class Decima32Array(FixedSizeBinaryArray): + """ + Concrete class for Arrow arrays of decimal32 data type. + """ + +cdef class Decimal64Array(FixedSizeBinaryArray): + """ + Concrete class for Arrow arrays of decimal64 data type. + """ + +cdef class Decimal128Array(FixedSizeBinaryArray): + """ + Concrete class for Arrow arrays of decimal128 data type. + """ + + +cdef class Decimal256Array(FixedSizeBinaryArray): + """ + Concrete class for Arrow arrays of decimal256 data type. + """ + +cdef class BaseListArray(Array): + + def flatten(self, recursive=False): + """ + Unnest this [Large]ListArray/[Large]ListViewArray/FixedSizeListArray + according to 'recursive'. + + Note that this method is different from ``self.values`` in that + it takes care of the slicing offset as well as null elements backed + by non-empty sub-lists. + + Parameters + ---------- + recursive : bool, default False, optional + When True, flatten this logical list-array recursively until an + array of non-list values is formed. + + When False, flatten only the top level. + + Returns + ------- + result : Array + + Examples + -------- + + Basic logical list-array's flatten + >>> import pyarrow as pa + >>> values = [1, 2, 3, 4] + >>> offsets = [2, 1, 0] + >>> sizes = [2, 2, 2] + >>> array = pa.ListViewArray.from_arrays(offsets, sizes, values) + >>> array + + [ + [ + 3, + 4 + ], + [ + 2, + 3 + ], + [ + 1, + 2 + ] + ] + >>> array.flatten() + + [ + 3, + 4, + 2, + 3, + 1, + 2 + ] + + When recursive=True, nested list arrays are flattened recursively + until an array of non-list values is formed. 
+ + >>> array = pa.array([ + ... None, + ... [ + ... [1, None, 2], + ... None, + ... [3, 4] + ... ], + ... [], + ... [ + ... [], + ... [5, 6], + ... None + ... ], + ... [ + ... [7, 8] + ... ] + ... ], type=pa.list_(pa.list_(pa.int64()))) + >>> array.flatten(True) + + [ + 1, + null, + 2, + 3, + 4, + 5, + 6, + 7, + 8 + ] + """ + options = _pc().ListFlattenOptions(recursive) + return _pc().list_flatten(self, options=options) + + def value_parent_indices(self): + """ + Return array of same length as list child values array where each + output value is the index of the parent list array slot containing each + child value. + + Examples + -------- + >>> import pyarrow as pa + >>> arr = pa.array([[1, 2, 3], [], None, [4]], + ... type=pa.list_(pa.int32())) + >>> arr.value_parent_indices() + + [ + 0, + 0, + 0, + 3 + ] + """ + return _pc().list_parent_indices(self) + + def value_lengths(self): + """ + Return integers array with values equal to the respective length of + each list element. Null list values are null in the output. + + Examples + -------- + >>> import pyarrow as pa + >>> arr = pa.array([[1, 2, 3], [], None, [4]], + ... type=pa.list_(pa.int32())) + >>> arr.value_lengths() + + [ + 3, + 0, + null, + 1 + ] + """ + return _pc().list_value_length(self) + + +cdef class ListArray(BaseListArray): + """ + Concrete class for Arrow arrays of a list data type. + """ + + @staticmethod + def from_arrays(offsets, values, DataType type=None, MemoryPool pool=None, mask=None): + """ + Construct ListArray from arrays of int32 offsets and values. + + Parameters + ---------- + offsets : Array (int32 type) + values : Array (any type) + type : DataType, optional + If not specified, a default ListType with the values' type is + used. + pool : MemoryPool, optional + mask : Array (boolean type), optional + Indicate which values are null (True) or not null (False). + + Returns + ------- + list_array : ListArray + + Examples + -------- + >>> import pyarrow as pa + >>> values = pa.array([1, 2, 3, 4]) + >>> offsets = pa.array([0, 2, 4]) + >>> pa.ListArray.from_arrays(offsets, values) + + [ + [ + 1, + 2 + ], + [ + 3, + 4 + ] + ] + >>> # nulls in the offsets array become null lists + >>> offsets = pa.array([0, None, 2, 4]) + >>> pa.ListArray.from_arrays(offsets, values) + + [ + [ + 1, + 2 + ], + null, + [ + 3, + 4 + ] + ] + """ + cdef: + Array _offsets, _values + shared_ptr[CArray] out + shared_ptr[CBuffer] c_mask + cdef CMemoryPool* cpool = maybe_unbox_memory_pool(pool) + + _offsets = asarray(offsets, type='int32') + _values = asarray(values) + + c_mask = c_mask_inverted_from_obj(mask, pool) + + if type is not None: + with nogil: + out = GetResultValue( + CListArray.FromArraysAndType( + type.sp_type, _offsets.ap[0], _values.ap[0], cpool, c_mask)) + else: + with nogil: + out = GetResultValue( + CListArray.FromArrays( + _offsets.ap[0], _values.ap[0], cpool, c_mask)) + cdef Array result = pyarrow_wrap_array(out) + result.validate() + return result + + @property + def values(self): + """ + Return the underlying array of values which backs the ListArray + ignoring the array's offset. + + If any of the list elements are null, but are backed by a + non-empty sub-list, those elements will be included in the + output. + + Compare with :meth:`flatten`, which returns only the non-null + values taking into consideration the array's offset. + + Returns + ------- + values : Array + + See Also + -------- + ListArray.flatten : ... 
+ + Examples + -------- + + The values include null elements from sub-lists: + + >>> import pyarrow as pa + >>> array = pa.array([[1, 2], None, [3, 4, None, 6]]) + >>> array.values + + [ + 1, + 2, + 3, + 4, + null, + 6 + ] + + If an array is sliced, the slice still uses the same + underlying data as the original array, just with an + offset. Since values ignores the offset, the values are the + same: + + >>> sliced = array.slice(1, 2) + >>> sliced + + [ + null, + [ + 3, + 4, + null, + 6 + ] + ] + >>> sliced.values + + [ + 1, + 2, + 3, + 4, + null, + 6 + ] + + """ + cdef CListArray* arr = self.ap + return pyarrow_wrap_array(arr.values()) + + @property + def offsets(self): + """ + Return the list offsets as an int32 array. + + The returned array will not have a validity bitmap, so you cannot + expect to pass it to `ListArray.from_arrays` and get back the same + list array if the original one has nulls. + + Returns + ------- + offsets : Int32Array + + Examples + -------- + >>> import pyarrow as pa + >>> array = pa.array([[1, 2], None, [3, 4, 5]]) + >>> array.offsets + + [ + 0, + 2, + 2, + 5 + ] + """ + return pyarrow_wrap_array(( self.ap).offsets()) + + +cdef class LargeListArray(BaseListArray): + """ + Concrete class for Arrow arrays of a large list data type. + + Identical to ListArray, but 64-bit offsets. + """ + + @staticmethod + def from_arrays(offsets, values, DataType type=None, MemoryPool pool=None, mask=None): + """ + Construct LargeListArray from arrays of int64 offsets and values. + + Parameters + ---------- + offsets : Array (int64 type) + values : Array (any type) + type : DataType, optional + If not specified, a default ListType with the values' type is + used. + pool : MemoryPool, optional + mask : Array (boolean type), optional + Indicate which values are null (True) or not null (False). + + Returns + ------- + list_array : LargeListArray + """ + cdef: + Array _offsets, _values + shared_ptr[CArray] out + shared_ptr[CBuffer] c_mask + + cdef CMemoryPool* cpool = maybe_unbox_memory_pool(pool) + + _offsets = asarray(offsets, type='int64') + _values = asarray(values) + + c_mask = c_mask_inverted_from_obj(mask, pool) + + if type is not None: + with nogil: + out = GetResultValue( + CLargeListArray.FromArraysAndType( + type.sp_type, _offsets.ap[0], _values.ap[0], cpool, c_mask)) + else: + with nogil: + out = GetResultValue( + CLargeListArray.FromArrays( + _offsets.ap[0], _values.ap[0], cpool, c_mask)) + cdef Array result = pyarrow_wrap_array(out) + result.validate() + return result + + @property + def values(self): + """ + Return the underlying array of values which backs the LargeListArray + ignoring the array's offset. + + If any of the list elements are null, but are backed by a + non-empty sub-list, those elements will be included in the + output. + + Compare with :meth:`flatten`, which returns only the non-null + values taking into consideration the array's offset. + + Returns + ------- + values : Array + + See Also + -------- + LargeListArray.flatten : ... + + Examples + -------- + + The values include null elements from the sub-lists: + + >>> import pyarrow as pa + >>> array = pa.array( + ... [[1, 2], None, [3, 4, None, 6]], + ... type=pa.large_list(pa.int32()), + ... ) + >>> array.values + + [ + 1, + 2, + 3, + 4, + null, + 6 + ] + + If an array is sliced, the slice still uses the same + underlying data as the original array, just with an + offset. 
Since values ignores the offset, the values are the + same: + + >>> sliced = array.slice(1, 2) + >>> sliced + + [ + null, + [ + 3, + 4, + null, + 6 + ] + ] + >>> sliced.values + + [ + 1, + 2, + 3, + 4, + null, + 6 + ] + """ + cdef CLargeListArray* arr = self.ap + return pyarrow_wrap_array(arr.values()) + + @property + def offsets(self): + """ + Return the list offsets as an int64 array. + + The returned array will not have a validity bitmap, so you cannot + expect to pass it to `LargeListArray.from_arrays` and get back the + same list array if the original one has nulls. + + Returns + ------- + offsets : Int64Array + """ + return pyarrow_wrap_array(( self.ap).offsets()) + + +cdef class ListViewArray(BaseListArray): + """ + Concrete class for Arrow arrays of a list view data type. + """ + + @staticmethod + def from_arrays(offsets, sizes, values, DataType type=None, MemoryPool pool=None, mask=None): + """ + Construct ListViewArray from arrays of int32 offsets, sizes, and values. + + Parameters + ---------- + offsets : Array (int32 type) + sizes : Array (int32 type) + values : Array (any type) + type : DataType, optional + If not specified, a default ListType with the values' type is + used. + pool : MemoryPool, optional + mask : Array (boolean type), optional + Indicate which values are null (True) or not null (False). + + Returns + ------- + list_view_array : ListViewArray + + Examples + -------- + >>> import pyarrow as pa + >>> values = pa.array([1, 2, 3, 4]) + >>> offsets = pa.array([0, 1, 2]) + >>> sizes = pa.array([2, 2, 2]) + >>> pa.ListViewArray.from_arrays(offsets, sizes, values) + + [ + [ + 1, + 2 + ], + [ + 2, + 3 + ], + [ + 3, + 4 + ] + ] + >>> # use a null mask to represent null values + >>> mask = pa.array([False, True, False]) + >>> pa.ListViewArray.from_arrays(offsets, sizes, values, mask=mask) + + [ + [ + 1, + 2 + ], + null, + [ + 3, + 4 + ] + ] + >>> # null values can be defined in either offsets or sizes arrays + >>> # WARNING: this will result in a copy of the offsets or sizes arrays + >>> offsets = pa.array([0, None, 2]) + >>> pa.ListViewArray.from_arrays(offsets, sizes, values) + + [ + [ + 1, + 2 + ], + null, + [ + 3, + 4 + ] + ] + """ + cdef: + Array _offsets, _sizes, _values + shared_ptr[CArray] out + shared_ptr[CBuffer] c_mask + CMemoryPool* cpool = maybe_unbox_memory_pool(pool) + + _offsets = asarray(offsets, type='int32') + _sizes = asarray(sizes, type='int32') + _values = asarray(values) + + c_mask = c_mask_inverted_from_obj(mask, pool) + + if type is not None: + with nogil: + out = GetResultValue( + CListViewArray.FromArraysAndType( + type.sp_type, _offsets.ap[0], _sizes.ap[0], _values.ap[0], cpool, c_mask)) + else: + with nogil: + out = GetResultValue( + CListViewArray.FromArrays( + _offsets.ap[0], _sizes.ap[0], _values.ap[0], cpool, c_mask)) + cdef Array result = pyarrow_wrap_array(out) + result.validate() + return result + + @property + def values(self): + """ + Return the underlying array of values which backs the ListViewArray + ignoring the array's offset and sizes. + + The values array may be out of order and/or contain additional values + that are not found in the logical representation of the array. The only + guarantee is that each non-null value in the ListView Array is contiguous. + + Compare with :meth:`flatten`, which returns only the non-null + values taking into consideration the array's order and offset. 
+ + Returns + ------- + values : Array + + Examples + -------- + The values include null elements from sub-lists: + + >>> import pyarrow as pa + >>> values = [1, 2, None, 3, 4] + >>> offsets = [0, 0, 1] + >>> sizes = [2, 0, 4] + >>> array = pa.ListViewArray.from_arrays(offsets, sizes, values) + >>> array + + [ + [ + 1, + 2 + ], + [], + [ + 2, + null, + 3, + 4 + ] + ] + >>> array.values + + [ + 1, + 2, + null, + 3, + 4 + ] + """ + cdef CListViewArray* arr = self.ap + return pyarrow_wrap_array(arr.values()) + + @property + def offsets(self): + """ + Return the list offsets as an int32 array. + + The returned array will not have a validity bitmap, so you cannot + expect to pass it to `ListViewArray.from_arrays` and get back the same + list array if the original one has nulls. + + Returns + ------- + offsets : Int32Array + + Examples + -------- + >>> import pyarrow as pa + >>> values = [1, 2, None, 3, 4] + >>> offsets = [0, 0, 1] + >>> sizes = [2, 0, 4] + >>> array = pa.ListViewArray.from_arrays(offsets, sizes, values) + >>> array.offsets + + [ + 0, + 0, + 1 + ] + """ + return pyarrow_wrap_array(( self.ap).offsets()) + + @property + def sizes(self): + """ + Return the list sizes as an int32 array. + + The returned array will not have a validity bitmap, so you cannot + expect to pass it to `ListViewArray.from_arrays` and get back the same + list array if the original one has nulls. + + Returns + ------- + sizes : Int32Array + + Examples + -------- + >>> import pyarrow as pa + >>> values = [1, 2, None, 3, 4] + >>> offsets = [0, 0, 1] + >>> sizes = [2, 0, 4] + >>> array = pa.ListViewArray.from_arrays(offsets, sizes, values) + >>> array.sizes + + [ + 2, + 0, + 4 + ] + """ + return pyarrow_wrap_array(( self.ap).sizes()) + + +cdef class LargeListViewArray(BaseListArray): + """ + Concrete class for Arrow arrays of a large list view data type. + + Identical to ListViewArray, but with 64-bit offsets. + """ + @staticmethod + def from_arrays(offsets, sizes, values, DataType type=None, MemoryPool pool=None, mask=None): + """ + Construct LargeListViewArray from arrays of int64 offsets and values. + + Parameters + ---------- + offsets : Array (int64 type) + sizes : Array (int64 type) + values : Array (any type) + type : DataType, optional + If not specified, a default ListType with the values' type is + used. + pool : MemoryPool, optional + mask : Array (boolean type), optional + Indicate which values are null (True) or not null (False). 
+ + Returns + ------- + list_view_array : LargeListViewArray + + Examples + -------- + >>> import pyarrow as pa + >>> values = pa.array([1, 2, 3, 4]) + >>> offsets = pa.array([0, 1, 2]) + >>> sizes = pa.array([2, 2, 2]) + >>> pa.LargeListViewArray.from_arrays(offsets, sizes, values) + + [ + [ + 1, + 2 + ], + [ + 2, + 3 + ], + [ + 3, + 4 + ] + ] + >>> # use a null mask to represent null values + >>> mask = pa.array([False, True, False]) + >>> pa.LargeListViewArray.from_arrays(offsets, sizes, values, mask=mask) + + [ + [ + 1, + 2 + ], + null, + [ + 3, + 4 + ] + ] + >>> # null values can be defined in either offsets or sizes arrays + >>> # WARNING: this will result in a copy of the offsets or sizes arrays + >>> offsets = pa.array([0, None, 2]) + >>> pa.LargeListViewArray.from_arrays(offsets, sizes, values) + + [ + [ + 1, + 2 + ], + null, + [ + 3, + 4 + ] + ] + """ + cdef: + Array _offsets, _sizes, _values + shared_ptr[CArray] out + shared_ptr[CBuffer] c_mask + CMemoryPool* cpool = maybe_unbox_memory_pool(pool) + + _offsets = asarray(offsets, type='int64') + _sizes = asarray(sizes, type='int64') + _values = asarray(values) + + c_mask = c_mask_inverted_from_obj(mask, pool) + + if type is not None: + with nogil: + out = GetResultValue( + CLargeListViewArray.FromArraysAndType( + type.sp_type, _offsets.ap[0], _sizes.ap[0], _values.ap[0], cpool, c_mask)) + else: + with nogil: + out = GetResultValue( + CLargeListViewArray.FromArrays( + _offsets.ap[0], _sizes.ap[0], _values.ap[0], cpool, c_mask)) + cdef Array result = pyarrow_wrap_array(out) + result.validate() + return result + + @property + def values(self): + """ + Return the underlying array of values which backs the LargeListArray + ignoring the array's offset. + + The values array may be out of order and/or contain additional values + that are not found in the logical representation of the array. The only + guarantee is that each non-null value in the ListView Array is contiguous. + + Compare with :meth:`flatten`, which returns only the non-null + values taking into consideration the array's order and offset. + + Returns + ------- + values : Array + + See Also + -------- + LargeListArray.flatten : ... + + Examples + -------- + + The values include null elements from sub-lists: + + >>> import pyarrow as pa + >>> values = [1, 2, None, 3, 4] + >>> offsets = [0, 0, 1] + >>> sizes = [2, 0, 4] + >>> array = pa.LargeListViewArray.from_arrays(offsets, sizes, values) + >>> array + + [ + [ + 1, + 2 + ], + [], + [ + 2, + null, + 3, + 4 + ] + ] + >>> array.values + + [ + 1, + 2, + null, + 3, + 4 + ] + """ + cdef CLargeListViewArray* arr = self.ap + return pyarrow_wrap_array(arr.values()) + + @property + def offsets(self): + """ + Return the list view offsets as an int64 array. + + The returned array will not have a validity bitmap, so you cannot + expect to pass it to `LargeListViewArray.from_arrays` and get back the + same list array if the original one has nulls. + + Returns + ------- + offsets : Int64Array + + Examples + -------- + + >>> import pyarrow as pa + >>> values = [1, 2, None, 3, 4] + >>> offsets = [0, 0, 1] + >>> sizes = [2, 0, 4] + >>> array = pa.LargeListViewArray.from_arrays(offsets, sizes, values) + >>> array.offsets + + [ + 0, + 0, + 1 + ] + """ + return pyarrow_wrap_array(( self.ap).offsets()) + + @property + def sizes(self): + """ + Return the list view sizes as an int64 array. 
+ + The returned array will not have a validity bitmap, so you cannot + expect to pass it to `LargeListViewArray.from_arrays` and get back the + same list array if the original one has nulls. + + Returns + ------- + sizes : Int64Array + + Examples + -------- + + >>> import pyarrow as pa + >>> values = [1, 2, None, 3, 4] + >>> offsets = [0, 0, 1] + >>> sizes = [2, 0, 4] + >>> array = pa.LargeListViewArray.from_arrays(offsets, sizes, values) + >>> array.sizes + + [ + 2, + 0, + 4 + ] + """ + return pyarrow_wrap_array(( self.ap).sizes()) + + +cdef class MapArray(ListArray): + """ + Concrete class for Arrow arrays of a map data type. + """ + + @staticmethod + def from_arrays(offsets, keys, items, DataType type=None, MemoryPool pool=None, mask=None): + """ + Construct MapArray from arrays of int32 offsets and key, item arrays. + + Parameters + ---------- + offsets : array-like or sequence (int32 type) + keys : array-like or sequence (any type) + items : array-like or sequence (any type) + type : DataType, optional + If not specified, a default MapArray with the keys' and items' type is used. + pool : MemoryPool + mask : Array (boolean type), optional + Indicate which values are null (True) or not null (False). + + Returns + ------- + map_array : MapArray + + Examples + -------- + First, let's understand the structure of our dataset when viewed in a rectangular data model. + The total of 5 respondents answered the question "How much did you like the movie x?". + The value -1 in the integer array means that the value is missing. The boolean array + represents the null bitmask corresponding to the missing values in the integer array. + + >>> import pyarrow as pa + >>> movies_rectangular = np.ma.masked_array([ + ... [10, -1, -1], + ... [8, 4, 5], + ... [-1, 10, 3], + ... [-1, -1, -1], + ... [-1, -1, -1] + ... ], + ... [ + ... [False, True, True], + ... [False, False, False], + ... [True, False, False], + ... [True, True, True], + ... [True, True, True], + ... ]) + + To represent the same data with the MapArray and from_arrays, the data is + formed like this: + + >>> offsets = [ + ... 0, # -- row 1 start + ... 1, # -- row 2 start + ... 4, # -- row 3 start + ... 6, # -- row 4 start + ... 6, # -- row 5 start + ... 6, # -- row 5 end + ... ] + >>> movies = [ + ... "Dark Knight", # ---------------------------------- row 1 + ... "Dark Knight", "Meet the Parents", "Superman", # -- row 2 + ... "Meet the Parents", "Superman", # ----------------- row 3 + ... ] + >>> likings = [ + ... 10, # -------- row 1 + ... 8, 4, 5, # --- row 2 + ... 10, 3 # ------ row 3 + ... ] + >>> pa.MapArray.from_arrays(offsets, movies, likings).to_pandas() + 0 [(Dark Knight, 10)] + 1 [(Dark Knight, 8), (Meet the Parents, 4), (Sup... + 2 [(Meet the Parents, 10), (Superman, 3)] + 3 [] + 4 [] + dtype: object + + If the data in the empty rows needs to be marked as missing, it's possible + to do so by modifying the offsets argument, so that we specify `None` as + the starting positions of the rows we want marked as missing. The end row + offset still has to refer to the existing value from keys (and values): + + >>> offsets = [ + ... 0, # ----- row 1 start + ... 1, # ----- row 2 start + ... 4, # ----- row 3 start + ... None, # -- row 4 start + ... None, # -- row 5 start + ... 6, # ----- row 5 end + ... ] + >>> pa.MapArray.from_arrays(offsets, movies, likings).to_pandas() + 0 [(Dark Knight, 10)] + 1 [(Dark Knight, 8), (Meet the Parents, 4), (Sup... 
+ 2 [(Meet the Parents, 10), (Superman, 3)] + 3 None + 4 None + dtype: object + """ + cdef: + Array _offsets, _keys, _items + shared_ptr[CArray] out + shared_ptr[CBuffer] c_mask + cdef CMemoryPool* cpool = maybe_unbox_memory_pool(pool) + + _offsets = asarray(offsets, type='int32') + _keys = asarray(keys) + _items = asarray(items) + + c_mask = c_mask_inverted_from_obj(mask, pool) + + if type is not None: + with nogil: + out = GetResultValue( + CMapArray.FromArraysAndType( + type.sp_type, _offsets.sp_array, + _keys.sp_array, _items.sp_array, cpool, c_mask)) + else: + with nogil: + out = GetResultValue( + CMapArray.FromArrays(_offsets.sp_array, + _keys.sp_array, + _items.sp_array, cpool, c_mask)) + cdef Array result = pyarrow_wrap_array(out) + result.validate() + return result + + @property + def keys(self): + """Flattened array of keys across all maps in array""" + return pyarrow_wrap_array(( self.ap).keys()) + + @property + def items(self): + """Flattened array of items across all maps in array""" + return pyarrow_wrap_array(( self.ap).items()) + + +cdef class FixedSizeListArray(BaseListArray): + """ + Concrete class for Arrow arrays of a fixed size list data type. + """ + + @staticmethod + def from_arrays(values, list_size=None, DataType type=None, mask=None): + """ + Construct FixedSizeListArray from array of values and a list length. + + Parameters + ---------- + values : Array (any type) + list_size : int + The fixed length of the lists. + type : DataType, optional + If not specified, a default ListType with the values' type and + `list_size` length is used. + mask : Array (boolean type), optional + Indicate which values are null (True) or not null (False). + + + Returns + ------- + FixedSizeListArray + + Examples + -------- + + Create from a values array and a list size: + + >>> import pyarrow as pa + >>> values = pa.array([1, 2, 3, 4]) + >>> arr = pa.FixedSizeListArray.from_arrays(values, 2) + >>> arr + + [ + [ + 1, + 2 + ], + [ + 3, + 4 + ] + ] + + Or create from a values array, list size and matching type: + + >>> typ = pa.list_(pa.field("values", pa.int64()), 2) + >>> arr = pa.FixedSizeListArray.from_arrays(values,type=typ) + >>> arr + + [ + [ + 1, + 2 + ], + [ + 3, + 4 + ] + ] + """ + cdef: + Array _values + int32_t _list_size + CResult[shared_ptr[CArray]] c_result + + _values = asarray(values) + + c_mask = c_mask_inverted_from_obj(mask, None) + + if type is not None: + if list_size is not None: + raise ValueError("Cannot specify both list_size and type") + with nogil: + c_result = CFixedSizeListArray.FromArraysAndType( + _values.sp_array, type.sp_type, c_mask) + else: + if list_size is None: + raise ValueError("Should specify one of list_size and type") + _list_size = list_size + with nogil: + c_result = CFixedSizeListArray.FromArrays( + _values.sp_array, _list_size, c_mask) + cdef Array result = pyarrow_wrap_array(GetResultValue(c_result)) + result.validate() + return result + + @property + def values(self): + """ + Return the underlying array of values which backs the + FixedSizeListArray. + + Note even null elements are included. + + Compare with :meth:`flatten`, which returns only the non-null + sub-list values. + + Returns + ------- + values : Array + + See Also + -------- + FixedSizeListArray.flatten : ... + + Examples + -------- + >>> import pyarrow as pa + >>> array = pa.array( + ... [[1, 2], None, [3, None]], + ... type=pa.list_(pa.int32(), 2) + ... 
) + >>> array.values + + [ + 1, + 2, + null, + null, + 3, + null + ] + + """ + cdef CFixedSizeListArray* arr = self.ap + return pyarrow_wrap_array(arr.values()) + + +cdef class UnionArray(Array): + """ + Concrete class for Arrow arrays of a Union data type. + """ + + def child(self, int pos): + """ + DEPRECATED, use field() instead. + + Parameters + ---------- + pos : int + The physical index of the union child field (not its type code). + + Returns + ------- + field : pyarrow.Field + The given child field. + """ + import warnings + warnings.warn("child is deprecated, use field", FutureWarning) + return self.field(pos) + + def field(self, int pos): + """ + Return the given child field as an individual array. + + For sparse unions, the returned array has its offset, length, + and null count adjusted. + + For dense unions, the returned array is unchanged. + + Parameters + ---------- + pos : int + The physical index of the union child field (not its type code). + + Returns + ------- + field : Array + The given child field. + """ + cdef shared_ptr[CArray] result + result = ( self.ap).field(pos) + if result != NULL: + return pyarrow_wrap_array(result) + raise KeyError("UnionArray does not have child {}".format(pos)) + + @property + def type_codes(self): + """Get the type codes array.""" + buf = pyarrow_wrap_buffer(( self.ap).type_codes()) + return Array.from_buffers(int8(), len(self), [None, buf]) + + @property + def offsets(self): + """ + Get the value offsets array (dense arrays only). + + Does not account for any slice offset. + """ + if self.type.mode != "dense": + raise ArrowTypeError("Can only get value offsets for dense arrays") + cdef CDenseUnionArray* dense = self.ap + buf = pyarrow_wrap_buffer(dense.value_offsets()) + return Array.from_buffers(int32(), len(self), [None, buf]) + + @staticmethod + def from_dense(Array types, Array value_offsets, list children, + list field_names=None, list type_codes=None): + """ + Construct dense UnionArray from arrays of int8 types, int32 offsets and + children arrays + + Parameters + ---------- + types : Array (int8 type) + value_offsets : Array (int32 type) + children : list + field_names : list + type_codes : list + + Returns + ------- + union_array : UnionArray + """ + cdef: + shared_ptr[CArray] out + vector[shared_ptr[CArray]] c + Array child + vector[c_string] c_field_names + vector[int8_t] c_type_codes + + for child in children: + c.push_back(child.sp_array) + if field_names is not None: + for x in field_names: + c_field_names.push_back(tobytes(x)) + if type_codes is not None: + for x in type_codes: + c_type_codes.push_back(x) + + with nogil: + out = GetResultValue(CDenseUnionArray.Make( + deref(types.ap), deref(value_offsets.ap), c, c_field_names, + c_type_codes)) + + cdef Array result = pyarrow_wrap_array(out) + result.validate() + return result + + @staticmethod + def from_sparse(Array types, list children, list field_names=None, + list type_codes=None): + """ + Construct sparse UnionArray from arrays of int8 types and children + arrays + + Parameters + ---------- + types : Array (int8 type) + children : list + field_names : list + type_codes : list + + Returns + ------- + union_array : UnionArray + """ + cdef: + shared_ptr[CArray] out + vector[shared_ptr[CArray]] c + Array child + vector[c_string] c_field_names + vector[int8_t] c_type_codes + + for child in children: + c.push_back(child.sp_array) + if field_names is not None: + for x in field_names: + c_field_names.push_back(tobytes(x)) + if type_codes is not None: + for x in type_codes: + 
c_type_codes.push_back(x) + + with nogil: + out = GetResultValue(CSparseUnionArray.Make( + deref(types.ap), c, c_field_names, c_type_codes)) + + cdef Array result = pyarrow_wrap_array(out) + result.validate() + return result + + +cdef class StringArray(Array): + """ + Concrete class for Arrow arrays of string (or utf8) data type. + """ + + @staticmethod + def from_buffers(int length, Buffer value_offsets, Buffer data, + Buffer null_bitmap=None, int null_count=-1, + int offset=0): + """ + Construct a StringArray from value_offsets and data buffers. + If there are nulls in the data, also a null_bitmap and the matching + null_count must be passed. + + Parameters + ---------- + length : int + value_offsets : Buffer + data : Buffer + null_bitmap : Buffer, optional + null_count : int, default 0 + offset : int, default 0 + + Returns + ------- + string_array : StringArray + """ + return Array.from_buffers(utf8(), length, + [null_bitmap, value_offsets, data], + null_count, offset) + + +cdef class LargeStringArray(Array): + """ + Concrete class for Arrow arrays of large string (or utf8) data type. + """ + + @staticmethod + def from_buffers(int length, Buffer value_offsets, Buffer data, + Buffer null_bitmap=None, int null_count=-1, + int offset=0): + """ + Construct a LargeStringArray from value_offsets and data buffers. + If there are nulls in the data, also a null_bitmap and the matching + null_count must be passed. + + Parameters + ---------- + length : int + value_offsets : Buffer + data : Buffer + null_bitmap : Buffer, optional + null_count : int, default 0 + offset : int, default 0 + + Returns + ------- + string_array : StringArray + """ + return Array.from_buffers(large_utf8(), length, + [null_bitmap, value_offsets, data], + null_count, offset) + + +cdef class StringViewArray(Array): + """ + Concrete class for Arrow arrays of string (or utf8) view data type. + """ + + +cdef class BinaryArray(Array): + """ + Concrete class for Arrow arrays of variable-sized binary data type. + """ + @property + def total_values_length(self): + """ + The number of bytes from beginning to end of the data buffer addressed + by the offsets of this BinaryArray. + """ + return ( self.ap).total_values_length() + + +cdef class LargeBinaryArray(Array): + """ + Concrete class for Arrow arrays of large variable-sized binary data type. + """ + @property + def total_values_length(self): + """ + The number of bytes from beginning to end of the data buffer addressed + by the offsets of this LargeBinaryArray. + """ + return ( self.ap).total_values_length() + + +cdef class BinaryViewArray(Array): + """ + Concrete class for Arrow arrays of variable-sized binary view data type. + """ + + +cdef class DictionaryArray(Array): + """ + Concrete class for dictionary-encoded Arrow arrays. + """ + + def dictionary_encode(self): + return self + + def dictionary_decode(self): + """ + Decodes the DictionaryArray to an Array. + """ + return self.dictionary.take(self.indices) + + @property + def dictionary(self): + cdef CDictionaryArray* darr = (self.ap) + + if self._dictionary is None: + self._dictionary = pyarrow_wrap_array(darr.dictionary()) + + return self._dictionary + + @property + def indices(self): + cdef CDictionaryArray* darr = (self.ap) + + if self._indices is None: + self._indices = pyarrow_wrap_array(darr.indices()) + + return self._indices + + @staticmethod + def from_buffers(DataType type, int64_t length, buffers, Array dictionary, + int64_t null_count=-1, int64_t offset=0): + """ + Construct a DictionaryArray from buffers. 
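+
+        A hedged sketch of the calling convention (illustrative; the buffers
+        are assumed to follow the layout of the index type, here int8 indices
+        with no validity bitmap)::
+
+            import pyarrow as pa
+            dict_type = pa.dictionary(pa.int8(), pa.utf8())
+            dictionary = pa.array(["a", "b"])
+            index_buffers = pa.array([0, 1, 0], type=pa.int8()).buffers()
+            darr = pa.DictionaryArray.from_buffers(
+                dict_type, 3, index_buffers, dictionary)
+            darr.to_pylist()   # -> ['a', 'b', 'a']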
+ + Parameters + ---------- + type : pyarrow.DataType + length : int + The number of values in the array. + buffers : List[Buffer] + The buffers backing the indices array. + dictionary : pyarrow.Array, ndarray or pandas.Series + The array of values referenced by the indices. + null_count : int, default -1 + The number of null entries in the indices array. Negative value means that + the null count is not known. + offset : int, default 0 + The array's logical offset (in values, not in bytes) from the + start of each buffer. + + Returns + ------- + dict_array : DictionaryArray + """ + cdef: + vector[shared_ptr[CBuffer]] c_buffers + shared_ptr[CDataType] c_type + shared_ptr[CArrayData] c_data + shared_ptr[CArray] c_result + + for buf in buffers: + c_buffers.push_back(pyarrow_unwrap_buffer(buf)) + + c_type = pyarrow_unwrap_data_type(type) + + with nogil: + c_data = CArrayData.Make( + c_type, length, c_buffers, null_count, offset) + c_data.get().dictionary = dictionary.sp_array.get().data() + c_result.reset(new CDictionaryArray(c_data)) + + cdef Array result = pyarrow_wrap_array(c_result) + result.validate() + return result + + @staticmethod + def from_arrays(indices, dictionary, mask=None, bint ordered=False, + bint from_pandas=False, bint safe=True, + MemoryPool memory_pool=None): + """ + Construct a DictionaryArray from indices and values. + + Parameters + ---------- + indices : pyarrow.Array, numpy.ndarray or pandas.Series, int type + Non-negative integers referencing the dictionary values by zero + based index. + dictionary : pyarrow.Array, ndarray or pandas.Series + The array of values referenced by the indices. + mask : ndarray or pandas.Series, bool type + True values indicate that indices are actually null. + ordered : bool, default False + Set to True if the category values are ordered. + from_pandas : bool, default False + If True, the indices should be treated as though they originated in + a pandas.Categorical (null encoded as -1). + safe : bool, default True + If True, check that the dictionary indices are in range. + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise uses default pool. + + Returns + ------- + dict_array : DictionaryArray + """ + cdef: + Array _indices, _dictionary + shared_ptr[CDataType] c_type + shared_ptr[CArray] c_result + + if isinstance(indices, Array): + if mask is not None: + raise NotImplementedError( + "mask not implemented with Arrow array inputs yet") + _indices = indices + else: + if from_pandas: + _indices = _codes_to_indices(indices, mask, None, memory_pool) + else: + _indices = array(indices, mask=mask, memory_pool=memory_pool) + + if isinstance(dictionary, Array): + _dictionary = dictionary + else: + _dictionary = array(dictionary, memory_pool=memory_pool) + + if not isinstance(_indices, IntegerArray): + raise ValueError('Indices must be integer type') + + cdef c_bool c_ordered = ordered + + c_type.reset(new CDictionaryType(_indices.type.sp_type, + _dictionary.sp_array.get().type(), + c_ordered)) + + if safe: + with nogil: + c_result = GetResultValue( + CDictionaryArray.FromArrays(c_type, _indices.sp_array, + _dictionary.sp_array)) + else: + c_result.reset(new CDictionaryArray(c_type, _indices.sp_array, + _dictionary.sp_array)) + + cdef Array result = pyarrow_wrap_array(c_result) + result.validate() + return result + + +cdef class StructArray(Array): + """ + Concrete class for Arrow arrays of a struct data type. + """ + + def field(self, index): + """ + Retrieves the child array belonging to field. 
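+
+        For example (illustrative)::
+
+            import pyarrow as pa
+            arr = pa.array([{"x": 1, "y": True}, {"x": 2, "y": None}])
+            arr.field("x")   # -> Int64Array of [1, 2]
+            arr.field(1)     # -> BooleanArray of [true, null]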
+ + Parameters + ---------- + index : Union[int, str] + Index / position or name of the field. + + Returns + ------- + result : Array + """ + cdef: + CStructArray* arr = self.ap + shared_ptr[CArray] child + + if isinstance(index, (bytes, str)): + child = arr.GetFieldByName(tobytes(index)) + if child == nullptr: + raise KeyError(index) + elif isinstance(index, int): + child = arr.field( + _normalize_index(index, self.ap.num_fields())) + else: + raise TypeError('Expected integer or string index') + + return pyarrow_wrap_array(child) + + def _flattened_field(self, index, MemoryPool memory_pool=None): + """ + Retrieves the child array belonging to field, + accounting for the parent array null bitmap. + + Parameters + ---------- + index : Union[int, str] + Index / position or name of the field. + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool. + + Returns + ------- + result : Array + """ + cdef: + CStructArray* arr = self.ap + shared_ptr[CArray] child + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + + if isinstance(index, (bytes, str)): + int_index = self.type.get_field_index(index) + if int_index < 0: + raise KeyError(index) + elif isinstance(index, int): + int_index = _normalize_index(index, self.ap.num_fields()) + else: + raise TypeError('Expected integer or string index') + + child = GetResultValue(arr.GetFlattenedField(int_index, pool)) + return pyarrow_wrap_array(child) + + def flatten(self, MemoryPool memory_pool=None): + """ + Return one individual array for each field in the struct. + + Parameters + ---------- + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool. + + Returns + ------- + result : List[Array] + """ + cdef: + vector[shared_ptr[CArray]] arrays + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + CStructArray* sarr = self.ap + + with nogil: + arrays = GetResultValue(sarr.Flatten(pool)) + + return [pyarrow_wrap_array(arr) for arr in arrays] + + @staticmethod + def from_arrays(arrays, names=None, fields=None, mask=None, + memory_pool=None, type=None): + """ + Construct StructArray from collection of arrays representing + each field in the struct. + + Either field names, field instances or a struct type must be passed. + + Parameters + ---------- + arrays : sequence of Array + names : List[str] (optional) + Field names for each struct child. + fields : List[Field] (optional) + Field instances for each struct child. + mask : pyarrow.Array[bool] (optional) + Indicate which values are null (True) or not null (False). + memory_pool : MemoryPool (optional) + For memory allocations, if required, otherwise uses default pool. + type : pyarrow.StructType (optional) + Struct type for name and type of each child. 
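+
+        For example (an illustrative sketch; the field names and values
+        are assumptions):
+
+        >>> import pyarrow as pa
+        >>> xs = pa.array([1, 2])
+        >>> ys = pa.array(['a', 'b'])
+        >>> pa.StructArray.from_arrays([xs, ys], names=['x', 'y'])  # doctest: +SKIP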
+ + Returns + ------- + result : StructArray + """ + cdef: + shared_ptr[CArray] c_array + shared_ptr[CBuffer] c_mask + vector[shared_ptr[CArray]] c_arrays + vector[c_string] c_names + vector[shared_ptr[CField]] c_fields + CResult[shared_ptr[CArray]] c_result + ssize_t num_arrays + ssize_t length + ssize_t i + Field py_field + DataType struct_type + + if fields is not None and type is not None: + raise ValueError('Must pass either fields or type, not both') + + if type is not None: + fields = [] + for field in type: + fields.append(field) + + if names is None and fields is None: + raise ValueError('Must pass either names or fields') + if names is not None and fields is not None: + raise ValueError('Must pass either names or fields, not both') + + c_mask = c_mask_inverted_from_obj(mask, memory_pool) + + arrays = [asarray(x) for x in arrays] + for arr in arrays: + c_array = pyarrow_unwrap_array(arr) + if c_array == nullptr: + raise TypeError(f"Expected Array, got {arr.__class__}") + c_arrays.push_back(c_array) + if names is not None: + for name in names: + c_names.push_back(tobytes(name)) + else: + for item in fields: + if isinstance(item, tuple): + py_field = field(*item) + else: + py_field = item + c_fields.push_back(py_field.sp_field) + + if (c_arrays.size() == 0 and c_names.size() == 0 and + c_fields.size() == 0): + # The C++ side doesn't allow this + if mask is None: + return array([], struct([])) + else: + return array([{}] * len(mask), struct([]), mask=mask) + + if names is not None: + # XXX Cannot pass "nullptr" for a shared_ptr argument: + # https://github.com/cython/cython/issues/3020 + c_result = CStructArray.MakeFromFieldNames( + c_arrays, c_names, c_mask, -1, 0) + else: + c_result = CStructArray.MakeFromFields( + c_arrays, c_fields, c_mask, -1, 0) + cdef Array result = pyarrow_wrap_array(GetResultValue(c_result)) + result.validate() + return result + + def sort(self, order="ascending", by=None, **kwargs): + """ + Sort the StructArray + + Parameters + ---------- + order : str, default "ascending" + Which order to sort values in. + Accepted values are "ascending", "descending". + by : str or None, default None + If to sort the array by one of its fields + or by the whole array. + **kwargs : dict, optional + Additional sorting options. + As allowed by :class:`SortOptions` + + Returns + ------- + result : StructArray + """ + if by is not None: + tosort, sort_keys = self._flattened_field(by), [("", order)] + else: + tosort, sort_keys = self, [(field.name, order) for field in self.type] + indices = _pc().sort_indices( + tosort, options=_pc().SortOptions(sort_keys=sort_keys, **kwargs) + ) + return self.take(indices) + + +cdef class RunEndEncodedArray(Array): + """ + Concrete class for Arrow run-end encoded arrays. 
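+
+    A run-end encoded array stores each run of repeated values once, together
+    with the logical index at which the run ends. As an illustrative sketch
+    (the values are assumptions), run_ends [2, 5] with values [1, 2] describe
+    the logical array [1, 1, 2, 2, 2]:
+
+    >>> import pyarrow as pa
+    >>> pa.RunEndEncodedArray.from_arrays([2, 5], [1, 2])  # doctest: +SKIP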
+ """ + + @staticmethod + def _from_arrays(type, allow_none_for_type, logical_length, run_ends, values, logical_offset): + cdef: + int64_t _logical_length + Array _run_ends + Array _values + int64_t _logical_offset + shared_ptr[CDataType] c_type + shared_ptr[CRunEndEncodedArray] ree_array + + _logical_length = logical_length + _logical_offset = logical_offset + + type = ensure_type(type, allow_none=allow_none_for_type) + if type is not None: + _run_ends = asarray(run_ends, type=type.run_end_type) + _values = asarray(values, type=type.value_type) + c_type = pyarrow_unwrap_data_type(type) + with nogil: + ree_array = GetResultValue(CRunEndEncodedArray.Make( + c_type, _logical_length, _run_ends.sp_array, _values.sp_array, _logical_offset)) + else: + _run_ends = asarray(run_ends) + _values = asarray(values) + with nogil: + ree_array = GetResultValue(CRunEndEncodedArray.MakeFromArrays( + _logical_length, _run_ends.sp_array, _values.sp_array, _logical_offset)) + cdef Array result = pyarrow_wrap_array(ree_array) + result.validate(full=True) + return result + + @staticmethod + def from_arrays(run_ends, values, type=None): + """ + Construct RunEndEncodedArray from run_ends and values arrays. + + Parameters + ---------- + run_ends : Array (int16, int32, or int64 type) + The run_ends array. + values : Array (any type) + The values array. + type : pyarrow.DataType, optional + The run_end_encoded(run_end_type, value_type) array type. + + Returns + ------- + RunEndEncodedArray + """ + logical_length = scalar(run_ends[-1]).as_py() if len(run_ends) > 0 else 0 + return RunEndEncodedArray._from_arrays(type, True, logical_length, + run_ends, values, 0) + + @staticmethod + def from_buffers(DataType type, length, buffers, null_count=-1, offset=0, + children=None): + """ + Construct a RunEndEncodedArray from all the parameters that make up an + Array. + + RunEndEncodedArrays do not have buffers, only children arrays, but this + implementation is needed to satisfy the Array interface. + + Parameters + ---------- + type : DataType + The run_end_encoded(run_end_type, value_type) type. + length : int + The logical length of the run-end encoded array. Expected to match + the last value of the run_ends array (children[0]) minus the offset. + buffers : List[Buffer] + Empty List or [None]. + null_count : int, default -1 + The number of null entries in the array. Run-end encoded arrays + are specified to not have valid bits and null_count always equals 0. + offset : int, default 0 + The array's logical offset (in values, not in bytes) from the + start of each buffer. + children : List[Array] + Nested type children containing the run_ends and values arrays. 
+ + Returns + ------- + RunEndEncodedArray + """ + children = children or [] + + if type.num_fields != len(children): + raise ValueError("RunEndEncodedType's expected number of children " + "({0}) did not match the passed number " + "({1}).".format(type.num_fields, len(children))) + + # buffers are validated as if we needed to pass them to C++, but + # _make_from_arrays will take care of filling in the expected + # buffers array containing a single NULL buffer on the C++ side + if len(buffers) == 0: + buffers = [None] + if buffers[0] is not None: + raise ValueError("RunEndEncodedType expects None as validity " + "bitmap, buffers[0] is not None") + if type.num_buffers != len(buffers): + raise ValueError("RunEndEncodedType's expected number of buffers " + "({0}) did not match the passed number " + "({1}).".format(type.num_buffers, len(buffers))) + + # null_count is also validated as if we needed it + if null_count != -1 and null_count != 0: + raise ValueError("RunEndEncodedType's expected null_count (0) " + "did not match passed number ({0})".format(null_count)) + + return RunEndEncodedArray._from_arrays(type, False, length, children[0], + children[1], offset) + + @property + def run_ends(self): + """ + An array holding the logical indexes of each run-end. + + The physical offset to the array is applied. + """ + cdef CRunEndEncodedArray* ree_array = (self.ap) + return pyarrow_wrap_array(ree_array.run_ends()) + + @property + def values(self): + """ + An array holding the values of each run. + + The physical offset to the array is applied. + """ + cdef CRunEndEncodedArray* ree_array = (self.ap) + return pyarrow_wrap_array(ree_array.values()) + + def find_physical_offset(self): + """ + Find the physical offset of this REE array. + + This is the offset of the run that contains the value of the first + logical element of this array considering its offset. + + This function uses binary-search, so it has a O(log N) cost. + """ + cdef CRunEndEncodedArray* ree_array = (self.ap) + return ree_array.FindPhysicalOffset() + + def find_physical_length(self): + """ + Find the physical length of this REE array. + + The physical length of an REE is the number of physical values (and + run-ends) necessary to represent the logical range of values from offset + to length. + + This function uses binary-search, so it has a O(log N) cost. + """ + cdef CRunEndEncodedArray* ree_array = (self.ap) + return ree_array.FindPhysicalLength() + + +cdef class ExtensionArray(Array): + """ + Concrete class for Arrow extension arrays. + """ + + @property + def storage(self): + cdef: + CExtensionArray* ext_array = (self.ap) + + return pyarrow_wrap_array(ext_array.storage()) + + @staticmethod + def from_storage(BaseExtensionType typ, Array storage): + """ + Construct ExtensionArray from type and storage array. + + Parameters + ---------- + typ : DataType + The extension type for the result array. + storage : Array + The underlying storage for the result array. + + Returns + ------- + ext_array : ExtensionArray + """ + cdef: + shared_ptr[CExtensionArray] ext_array + + if storage.type != typ.storage_type: + raise TypeError("Incompatible storage type {0} " + "for extension type {1}".format(storage.type, typ)) + + ext_array = make_shared[CExtensionArray](typ.sp_type, storage.sp_array) + cdef Array result = pyarrow_wrap_array( ext_array) + result.validate() + return result + + +class JsonArray(ExtensionArray): + """ + Concrete class for Arrow arrays of JSON data type. 
+ + This does not guarantee that the JSON data actually + is valid JSON. + + Examples + -------- + Define the extension type for JSON array + + >>> import pyarrow as pa + >>> json_type = pa.json_(pa.large_utf8()) + + Create an extension array + + >>> arr = [None, '{ "id":30, "values":["a", "b"] }'] + >>> storage = pa.array(arr, pa.large_utf8()) + >>> pa.ExtensionArray.from_storage(json_type, storage) + + [ + null, + "{ "id":30, "values":["a", "b"] }" + ] + """ + + +class UuidArray(ExtensionArray): + """ + Concrete class for Arrow arrays of UUID data type. + """ + + +cdef class FixedShapeTensorArray(ExtensionArray): + """ + Concrete class for fixed shape tensor extension arrays. + + Examples + -------- + Define the extension type for tensor array + + >>> import pyarrow as pa + >>> tensor_type = pa.fixed_shape_tensor(pa.int32(), [2, 2]) + + Create an extension array + + >>> arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]] + >>> storage = pa.array(arr, pa.list_(pa.int32(), 4)) + >>> pa.ExtensionArray.from_storage(tensor_type, storage) + + [ + [ + 1, + 2, + 3, + 4 + ], + [ + 10, + 20, + 30, + 40 + ], + [ + 100, + 200, + 300, + 400 + ] + ] + """ + + def to_numpy_ndarray(self): + """ + Convert fixed shape tensor extension array to a multi-dimensional numpy.ndarray. + + The resulting ndarray will have (ndim + 1) dimensions. + The size of the first dimension will be the length of the fixed shape tensor array + and the rest of the dimensions will match the permuted shape of the fixed + shape tensor. + + The conversion is zero-copy. + + Returns + ------- + numpy.ndarray + Ndarray representing tensors in the fixed shape tensor array concatenated + along the first dimension. + """ + + return self.to_tensor().to_numpy() + + def to_tensor(self): + """ + Convert fixed shape tensor extension array to a pyarrow.Tensor. + + The resulting Tensor will have (ndim + 1) dimensions. + The size of the first dimension will be the length of the fixed shape tensor array + and the rest of the dimensions will match the permuted shape of the fixed + shape tensor. + + The conversion is zero-copy. + + Returns + ------- + pyarrow.Tensor + Tensor representing tensors in the fixed shape tensor array concatenated + along the first dimension. + """ + + cdef: + CFixedShapeTensorArray* ext_array = (self.ap) + CResult[shared_ptr[CTensor]] ctensor + with nogil: + ctensor = ext_array.ToTensor() + return pyarrow_wrap_tensor(GetResultValue(ctensor)) + + @staticmethod + def from_numpy_ndarray(obj): + """ + Convert numpy tensors (ndarrays) to a fixed shape tensor extension array. + The first dimension of ndarray will become the length of the fixed + shape tensor array. + If input array data is not contiguous a copy will be made. + + Parameters + ---------- + obj : numpy.ndarray + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> arr = np.array( + ... [[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], + ... 
dtype=np.float32) + >>> pa.FixedShapeTensorArray.from_numpy_ndarray(arr) + + [ + [ + 1, + 2, + 3, + 4, + 5, + 6 + ], + [ + 1, + 2, + 3, + 4, + 5, + 6 + ] + ] + """ + + if len(obj.shape) < 2: + raise ValueError( + "Cannot convert 1D array or scalar to fixed shape tensor array") + if np.prod(obj.shape) == 0: + raise ValueError("Expected a non-empty ndarray") + + permutation = (-np.array(obj.strides)).argsort(kind='stable') + if permutation[0] != 0: + raise ValueError('First stride needs to be largest to ensure that ' + 'individual tensor data is contiguous in memory.') + + arrow_type = from_numpy_dtype(obj.dtype) + shape = np.take(obj.shape, permutation) + values = np.ravel(obj, order="K") + + return ExtensionArray.from_storage( + fixed_shape_tensor(arrow_type, shape[1:], permutation=permutation[1:] - 1), + FixedSizeListArray.from_arrays(values, shape[1:].prod()) + ) + + +cdef class OpaqueArray(ExtensionArray): + """ + Concrete class for opaque extension arrays. + + Examples + -------- + Define the extension type for an opaque array + + >>> import pyarrow as pa + >>> opaque_type = pa.opaque( + ... pa.binary(), + ... type_name="geometry", + ... vendor_name="postgis", + ... ) + + Create an extension array + + >>> arr = [None, b"data"] + >>> storage = pa.array(arr, pa.binary()) + >>> pa.ExtensionArray.from_storage(opaque_type, storage) + + [ + null, + 64617461 + ] + """ + + +cdef class Bool8Array(ExtensionArray): + """ + Concrete class for bool8 extension arrays. + + Examples + -------- + Define the extension type for an bool8 array + + >>> import pyarrow as pa + >>> bool8_type = pa.bool8() + + Create an extension array + + >>> arr = [-1, 0, 1, 2, None] + >>> storage = pa.array(arr, pa.int8()) + >>> pa.ExtensionArray.from_storage(bool8_type, storage) + + [ + -1, + 0, + 1, + 2, + null + ] + """ + + def to_numpy(self, zero_copy_only=True, writable=False): + """ + Return a NumPy bool view or copy of this array. + + By default, tries to return a view of this array. This is only + supported for arrays without any nulls. + + Parameters + ---------- + zero_copy_only : bool, default True + If True, an exception will be raised if the conversion to a numpy + array would require copying the underlying data (e.g. in presence + of nulls). + writable : bool, default False + For numpy arrays created with zero copy (view on the Arrow data), + the resulting array is not writable (Arrow data is immutable). + By setting this to True, a copy of the array is made to ensure + it is writable. + + Returns + ------- + array : numpy.ndarray + """ + if not writable: + try: + return self.storage.to_numpy().view(np.bool_) + except ArrowInvalid as e: + if zero_copy_only: + raise e + + return _pc().not_equal(self.storage, 0).to_numpy(zero_copy_only=zero_copy_only, writable=writable) + + @staticmethod + def from_storage(Int8Array storage): + """ + Construct Bool8Array from Int8Array storage. + + Parameters + ---------- + storage : Int8Array + The underlying storage for the result array. + + Returns + ------- + bool8_array : Bool8Array + """ + return ExtensionArray.from_storage(bool8(), storage) + + @staticmethod + def from_numpy(obj): + """ + Convert numpy array to a bool8 extension array without making a copy. + The input array must be 1-dimensional, with either bool_ or int8 dtype. 
+ + Parameters + ---------- + obj : numpy.ndarray + + Returns + ------- + bool8_array : Bool8Array + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> arr = np.array([True, False, True], dtype=np.bool_) + >>> pa.Bool8Array.from_numpy(arr) + + [ + 1, + 0, + 1 + ] + """ + + if obj.ndim != 1: + raise ValueError(f"Cannot convert {obj.ndim}-D array to bool8 array") + + if obj.dtype not in [np.bool_, np.int8]: + raise TypeError(f"Array dtype {obj.dtype} incompatible with bool8 storage") + + storage_arr = array(obj.view(np.int8), type=int8()) + return Bool8Array.from_storage(storage_arr) + + +cdef dict _array_classes = { + _Type_NA: NullArray, + _Type_BOOL: BooleanArray, + _Type_UINT8: UInt8Array, + _Type_UINT16: UInt16Array, + _Type_UINT32: UInt32Array, + _Type_UINT64: UInt64Array, + _Type_INT8: Int8Array, + _Type_INT16: Int16Array, + _Type_INT32: Int32Array, + _Type_INT64: Int64Array, + _Type_DATE32: Date32Array, + _Type_DATE64: Date64Array, + _Type_TIMESTAMP: TimestampArray, + _Type_TIME32: Time32Array, + _Type_TIME64: Time64Array, + _Type_DURATION: DurationArray, + _Type_INTERVAL_MONTH_DAY_NANO: MonthDayNanoIntervalArray, + _Type_HALF_FLOAT: HalfFloatArray, + _Type_FLOAT: FloatArray, + _Type_DOUBLE: DoubleArray, + _Type_LIST: ListArray, + _Type_LARGE_LIST: LargeListArray, + _Type_LIST_VIEW: ListViewArray, + _Type_LARGE_LIST_VIEW: LargeListViewArray, + _Type_MAP: MapArray, + _Type_FIXED_SIZE_LIST: FixedSizeListArray, + _Type_SPARSE_UNION: UnionArray, + _Type_DENSE_UNION: UnionArray, + _Type_BINARY: BinaryArray, + _Type_STRING: StringArray, + _Type_LARGE_BINARY: LargeBinaryArray, + _Type_LARGE_STRING: LargeStringArray, + _Type_BINARY_VIEW: BinaryViewArray, + _Type_STRING_VIEW: StringViewArray, + _Type_DICTIONARY: DictionaryArray, + _Type_FIXED_SIZE_BINARY: FixedSizeBinaryArray, + _Type_DECIMAL32: Decimal32Array, + _Type_DECIMAL64: Decimal64Array, + _Type_DECIMAL128: Decimal128Array, + _Type_DECIMAL256: Decimal256Array, + _Type_STRUCT: StructArray, + _Type_RUN_END_ENCODED: RunEndEncodedArray, + _Type_EXTENSION: ExtensionArray, +} + + +cdef inline shared_ptr[CBuffer] c_mask_inverted_from_obj(object mask, MemoryPool pool) except *: + """ + Convert mask array obj to c_mask while also inverting to signify 1 for valid and 0 for null + """ + cdef shared_ptr[CBuffer] c_mask + if mask is None: + c_mask = shared_ptr[CBuffer]() + elif isinstance(mask, Array): + if mask.type.id != Type_BOOL: + raise TypeError('Mask must be a pyarrow.Array of type boolean') + if mask.null_count != 0: + raise ValueError('Mask must not contain nulls') + inverted_mask = _pc().invert(mask, memory_pool=pool) + c_mask = pyarrow_unwrap_buffer(inverted_mask.buffers()[1]) + else: + raise TypeError('Mask must be a pyarrow.Array of type boolean') + return c_mask + + +cdef object get_array_class_from_type( + const shared_ptr[CDataType]& sp_data_type): + cdef CDataType* data_type = sp_data_type.get() + if data_type == NULL: + raise ValueError('Array data type was NULL') + + if data_type.id() == _Type_EXTENSION: + py_ext_data_type = pyarrow_wrap_data_type(sp_data_type) + return py_ext_data_type.__arrow_ext_class__() + else: + return _array_classes[data_type.id()] + + +cdef object get_values(object obj, bint* is_series): + if pandas_api.is_series(obj) or pandas_api.is_index(obj): + result = pandas_api.get_values(obj) + is_series[0] = True + elif isinstance(obj, np.ndarray): + result = obj + is_series[0] = False + else: + result = pandas_api.series(obj, copy=False).values + is_series[0] = False + + return 
result + + +def concat_arrays(arrays, MemoryPool memory_pool=None): + """ + Concatenate the given arrays. + + The contents of the input arrays are copied into the returned array. + + Raises + ------ + ArrowInvalid + If not all of the arrays have the same type. + + Parameters + ---------- + arrays : iterable of pyarrow.Array + Arrays to concatenate, must be identically typed. + memory_pool : MemoryPool, default None + For memory allocations. If None, the default pool is used. + + Examples + -------- + >>> import pyarrow as pa + >>> arr1 = pa.array([2, 4, 5, 100]) + >>> arr2 = pa.array([2, 4]) + >>> pa.concat_arrays([arr1, arr2]) + + [ + 2, + 4, + 5, + 100, + 2, + 4 + ] + + """ + cdef: + vector[shared_ptr[CArray]] c_arrays + shared_ptr[CArray] c_concatenated + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + + for array in arrays: + if not isinstance(array, Array): + raise TypeError("Iterable should contain Array objects, " + "got {0} instead".format(type(array))) + c_arrays.push_back(pyarrow_unwrap_array(array)) + + with nogil: + c_concatenated = GetResultValue(Concatenate(c_arrays, pool)) + + return pyarrow_wrap_array(c_concatenated) + + +def _empty_array(DataType type): + """ + Create empty array of the given type. + """ + if type.id == Type_DICTIONARY: + arr = DictionaryArray.from_arrays( + _empty_array(type.index_type), _empty_array(type.value_type), + ordered=type.ordered) + else: + arr = array([], type=type) + return arr diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/benchmark.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/benchmark.pxi new file mode 100644 index 0000000000000000000000000000000000000000..ab251017db78706c97c7dee8044636c55c80167e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/benchmark.pxi @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +def benchmark_PandasObjectIsNull(list obj): + Benchmark_PandasObjectIsNull(obj) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/builder.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/builder.pxi new file mode 100644 index 0000000000000000000000000000000000000000..fbab5bbdb5a0113b107a0a7db029883dabaf7f78 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/builder.pxi @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import math + + +cdef class StringBuilder(_Weakrefable): + """ + Builder class for UTF8 strings. + + This class exposes facilities for incrementally adding string values and + building the null bitmap for a pyarrow.Array (type='string'). + """ + cdef: + unique_ptr[CStringBuilder] builder + + def __cinit__(self, MemoryPool memory_pool=None): + cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + self.builder.reset(new CStringBuilder(pool)) + + def append(self, value): + """ + Append a single value to the builder. + + The value can either be a string/bytes object or a null value + (np.nan or None). + + Parameters + ---------- + value : string/bytes or np.nan/None + The value to append to the string array builder. + """ + if isinstance(value, (bytes, str)): + self.builder.get().Append(tobytes(value)) + elif value is None or math.isnan(value): + self.builder.get().AppendNull() + else: + raise TypeError('StringBuilder only accepts string objects') + + def append_values(self, values): + """ + Append all the values from an iterable. + + Parameters + ---------- + values : iterable of string/bytes or np.nan/None values + The values to append to the string array builder. + """ + for value in values: + self.append(value) + + def finish(self): + """ + Return result of builder as an Array object; also resets the builder. + + Returns + ------- + array : pyarrow.Array + """ + cdef shared_ptr[CArray] out + with nogil: + self.builder.get().Finish(&out) + return pyarrow_wrap_array(out) + + @property + def null_count(self): + return self.builder.get().null_count() + + def __len__(self): + return self.builder.get().length() + + +cdef class StringViewBuilder(_Weakrefable): + """ + Builder class for UTF8 string views. + + This class exposes facilities for incrementally adding string values and + building the null bitmap for a pyarrow.Array (type='string_view'). + """ + cdef: + unique_ptr[CStringViewBuilder] builder + + def __cinit__(self, MemoryPool memory_pool=None): + cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + self.builder.reset(new CStringViewBuilder(pool)) + + def append(self, value): + """ + Append a single value to the builder. + + The value can either be a string/bytes object or a null value + (np.nan or None). + + Parameters + ---------- + value : string/bytes or np.nan/None + The value to append to the string array builder. + """ + if isinstance(value, (bytes, str)): + self.builder.get().Append(tobytes(value)) + elif value is None or math.isnan(value): + self.builder.get().AppendNull() + else: + raise TypeError('StringViewBuilder only accepts string objects') + + def append_values(self, values): + """ + Append all the values from an iterable. + + Parameters + ---------- + values : iterable of string/bytes or np.nan/None values + The values to append to the string array builder. + """ + for value in values: + self.append(value) + + def finish(self): + """ + Return result of builder as an Array object; also resets the builder. 
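+
+        A minimal end-to-end sketch of the builder (the appended values are
+        illustrative assumptions):
+
+        >>> builder = StringViewBuilder()
+        >>> builder.append("hello")
+        >>> builder.append_values(["a", None])
+        >>> builder.finish()  # doctest: +SKIP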
+ + Returns + ------- + array : pyarrow.Array + """ + cdef shared_ptr[CArray] out + with nogil: + self.builder.get().Finish(&out) + return pyarrow_wrap_array(out) + + @property + def null_count(self): + return self.builder.get().null_count() + + def __len__(self): + return self.builder.get().length() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/compat.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/compat.pxi new file mode 100644 index 0000000000000000000000000000000000000000..8cf106d5609b50dd84c082dcfd36aee5b16fbee4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/compat.pxi @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +def encode_file_path(path): + if isinstance(path, str): + # POSIX systems can handle utf-8. UTF8 is converted to utf16-le in + # libarrow + encoded_path = path.encode('utf-8') + else: + encoded_path = path + + # Windows file system requires utf-16le for file names; Arrow C++ libraries + # will convert utf8 to utf16 + return encoded_path + + +# Starting with Python 3.7, dicts are guaranteed to be insertion-ordered. +ordered_dict = dict + + +try: + import cloudpickle as pickle +except ImportError: + import pickle + + +def tobytes(o): + """ + Encode a unicode or bytes string to bytes. + + Parameters + ---------- + o : str or bytes + Input string. + """ + if isinstance(o, str): + return o.encode('utf8') + else: + return o + + +def frombytes(o, *, safe=False): + """ + Decode the given bytestring to unicode. + + Parameters + ---------- + o : bytes-like + Input object. + safe : bool, default False + If true, raise on encoding errors. + """ + if safe: + return o.decode('utf8', errors='replace') + else: + return o.decode('utf8') diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/csv.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/csv.py new file mode 100644 index 0000000000000000000000000000000000000000..1ae197f9f200f44d8a8a65851a89025f61c4d842 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/csv.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from pyarrow._csv import ( # noqa + ReadOptions, ParseOptions, ConvertOptions, ISO8601, + open_csv, read_csv, CSVStreamingReader, write_csv, + WriteOptions, CSVWriter, InvalidRow) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/dataset.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c61e13ee7580179bb93ae6882b5b2fcf5b8faa85 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/dataset.py @@ -0,0 +1,1039 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Dataset is currently unstable. APIs subject to change without notice.""" + +import pyarrow as pa +from pyarrow.util import _is_iterable, _stringify_path, _is_path_like + +try: + from pyarrow._dataset import ( # noqa + CsvFileFormat, + CsvFragmentScanOptions, + JsonFileFormat, + JsonFragmentScanOptions, + Dataset, + DatasetFactory, + DirectoryPartitioning, + FeatherFileFormat, + FilenamePartitioning, + FileFormat, + FileFragment, + FileSystemDataset, + FileSystemDatasetFactory, + FileSystemFactoryOptions, + FileWriteOptions, + Fragment, + FragmentScanOptions, + HivePartitioning, + IpcFileFormat, + IpcFileWriteOptions, + InMemoryDataset, + Partitioning, + PartitioningFactory, + Scanner, + TaggedRecordBatch, + UnionDataset, + UnionDatasetFactory, + WrittenFile, + get_partition_keys, + get_partition_keys as _get_partition_keys, # keep for backwards compatibility + _filesystemdataset_write, + ) +except ImportError as exc: + raise ImportError( + f"The pyarrow installation is not built with support for 'dataset' ({str(exc)})" + ) from None + +# keep Expression functionality exposed here for backwards compatibility +from pyarrow.compute import Expression, scalar, field # noqa + + +_orc_available = False +_orc_msg = ( + "The pyarrow installation is not built with support for the ORC file " + "format." +) + +try: + from pyarrow._dataset_orc import OrcFileFormat + _orc_available = True +except ImportError: + pass + +_parquet_available = False +_parquet_msg = ( + "The pyarrow installation is not built with support for the Parquet file " + "format." 
+) + +try: + from pyarrow._dataset_parquet import ( # noqa + ParquetDatasetFactory, + ParquetFactoryOptions, + ParquetFileFormat, + ParquetFileFragment, + ParquetFileWriteOptions, + ParquetFragmentScanOptions, + ParquetReadOptions, + RowGroupInfo, + ) + _parquet_available = True +except ImportError: + pass + + +try: + from pyarrow._dataset_parquet_encryption import ( # noqa + ParquetDecryptionConfig, + ParquetEncryptionConfig, + ) +except ImportError: + pass + + +def __getattr__(name): + if name == "OrcFileFormat" and not _orc_available: + raise ImportError(_orc_msg) + + if name == "ParquetFileFormat" and not _parquet_available: + raise ImportError(_parquet_msg) + + raise AttributeError( + "module 'pyarrow.dataset' has no attribute '{0}'".format(name) + ) + + +def partitioning(schema=None, field_names=None, flavor=None, + dictionaries=None): + """ + Specify a partitioning scheme. + + The supported schemes include: + + - "DirectoryPartitioning": this scheme expects one segment in the file path + for each field in the specified schema (all fields are required to be + present). For example given schema the path + "/2009/11" would be parsed to ("year"_ == 2009 and "month"_ == 11). + - "HivePartitioning": a scheme for "/$key=$value/" nested directories as + found in Apache Hive. This is a multi-level, directory based partitioning + scheme. Data is partitioned by static values of a particular column in + the schema. Partition keys are represented in the form $key=$value in + directory names. Field order is ignored, as are missing or unrecognized + field names. + For example, given schema, a possible + path would be "/year=2009/month=11/day=15" (but the field order does not + need to match). + - "FilenamePartitioning": this scheme expects the partitions will have + filenames containing the field values separated by "_". + For example, given schema, a possible + partition filename "2009_11_part-0.parquet" would be parsed + to ("year"_ == 2009 and "month"_ == 11). + + Parameters + ---------- + schema : pyarrow.Schema, default None + The schema that describes the partitions present in the file path. + If not specified, and `field_names` and/or `flavor` are specified, + the schema will be inferred from the file path (and a + PartitioningFactory is returned). + field_names : list of str, default None + A list of strings (field names). If specified, the schema's types are + inferred from the file paths (only valid for DirectoryPartitioning). + flavor : str, default None + The default is DirectoryPartitioning. Specify ``flavor="hive"`` for + a HivePartitioning, and ``flavor="filename"`` for a + FilenamePartitioning. + dictionaries : dict[str, Array] + If the type of any field of `schema` is a dictionary type, the + corresponding entry of `dictionaries` must be an array containing + every value which may be taken by the corresponding column or an + error will be raised in parsing. Alternatively, pass `infer` to have + Arrow discover the dictionary values, in which case a + PartitioningFactory is returned. + + Returns + ------- + Partitioning or PartitioningFactory + The partitioning scheme + + Examples + -------- + + Specify the Schema for paths like "/2009/June": + + >>> import pyarrow as pa + >>> import pyarrow.dataset as ds + >>> part = ds.partitioning(pa.schema([("year", pa.int16()), + ... 
("month", pa.string())])) + + or let the types be inferred by only specifying the field names: + + >>> part = ds.partitioning(field_names=["year", "month"]) + + For paths like "/2009/June", the year will be inferred as int32 while month + will be inferred as string. + + Specify a Schema with dictionary encoding, providing dictionary values: + + >>> part = ds.partitioning( + ... pa.schema([ + ... ("year", pa.int16()), + ... ("month", pa.dictionary(pa.int8(), pa.string())) + ... ]), + ... dictionaries={ + ... "month": pa.array(["January", "February", "March"]), + ... }) + + Alternatively, specify a Schema with dictionary encoding, but have Arrow + infer the dictionary values: + + >>> part = ds.partitioning( + ... pa.schema([ + ... ("year", pa.int16()), + ... ("month", pa.dictionary(pa.int8(), pa.string())) + ... ]), + ... dictionaries="infer") + + Create a Hive scheme for a path like "/year=2009/month=11": + + >>> part = ds.partitioning( + ... pa.schema([("year", pa.int16()), ("month", pa.int8())]), + ... flavor="hive") + + A Hive scheme can also be discovered from the directory structure (and + types will be inferred): + + >>> part = ds.partitioning(flavor="hive") + """ + if flavor is None: + # default flavor + if schema is not None: + if field_names is not None: + raise ValueError( + "Cannot specify both 'schema' and 'field_names'") + if dictionaries == 'infer': + return DirectoryPartitioning.discover(schema=schema) + return DirectoryPartitioning(schema, dictionaries) + elif field_names is not None: + if isinstance(field_names, list): + return DirectoryPartitioning.discover(field_names) + else: + raise ValueError( + "Expected list of field names, got {}".format( + type(field_names))) + else: + raise ValueError( + "For the default directory flavor, need to specify " + "a Schema or a list of field names") + if flavor == "filename": + if schema is not None: + if field_names is not None: + raise ValueError( + "Cannot specify both 'schema' and 'field_names'") + if dictionaries == 'infer': + return FilenamePartitioning.discover(schema=schema) + return FilenamePartitioning(schema, dictionaries) + elif field_names is not None: + if isinstance(field_names, list): + return FilenamePartitioning.discover(field_names) + else: + raise ValueError( + "Expected list of field names, got {}".format( + type(field_names))) + else: + raise ValueError( + "For the filename flavor, need to specify " + "a Schema or a list of field names") + elif flavor == 'hive': + if field_names is not None: + raise ValueError("Cannot specify 'field_names' for flavor 'hive'") + elif schema is not None: + if isinstance(schema, pa.Schema): + if dictionaries == 'infer': + return HivePartitioning.discover(schema=schema) + return HivePartitioning(schema, dictionaries) + else: + raise ValueError( + "Expected Schema for 'schema', got {}".format( + type(schema))) + else: + return HivePartitioning.discover() + else: + raise ValueError("Unsupported flavor") + + +def _ensure_partitioning(scheme): + """ + Validate input and return a Partitioning(Factory). + + It passes None through if no partitioning scheme is defined. 
+ """ + if scheme is None: + pass + elif isinstance(scheme, str): + scheme = partitioning(flavor=scheme) + elif isinstance(scheme, list): + scheme = partitioning(field_names=scheme) + elif isinstance(scheme, (Partitioning, PartitioningFactory)): + pass + else: + raise ValueError("Expected Partitioning or PartitioningFactory, got {}" + .format(type(scheme))) + return scheme + + +def _ensure_format(obj): + if isinstance(obj, FileFormat): + return obj + elif obj == "parquet": + if not _parquet_available: + raise ValueError(_parquet_msg) + return ParquetFileFormat() + elif obj in {"ipc", "arrow"}: + return IpcFileFormat() + elif obj == "feather": + return FeatherFileFormat() + elif obj == "csv": + return CsvFileFormat() + elif obj == "orc": + if not _orc_available: + raise ValueError(_orc_msg) + return OrcFileFormat() + elif obj == "json": + return JsonFileFormat() + else: + raise ValueError("format '{}' is not supported".format(obj)) + + +def _ensure_multiple_sources(paths, filesystem=None): + """ + Treat a list of paths as files belonging to a single file system + + If the file system is local then also validates that all paths + are referencing existing *files* otherwise any non-file paths will be + silently skipped (for example on a remote filesystem). + + Parameters + ---------- + paths : list of path-like + Note that URIs are not allowed. + filesystem : FileSystem or str, optional + If an URI is passed, then its path component will act as a prefix for + the file paths. + + Returns + ------- + (FileSystem, list of str) + File system object and a list of normalized paths. + + Raises + ------ + TypeError + If the passed filesystem has wrong type. + IOError + If the file system is local and a referenced path is not available or + not a file. + """ + from pyarrow.fs import ( + LocalFileSystem, SubTreeFileSystem, _MockFileSystem, FileType, + _ensure_filesystem + ) + + if filesystem is None: + # fall back to local file system as the default + filesystem = LocalFileSystem() + else: + # construct a filesystem if it is a valid URI + filesystem = _ensure_filesystem(filesystem) + + is_local = ( + isinstance(filesystem, (LocalFileSystem, _MockFileSystem)) or + (isinstance(filesystem, SubTreeFileSystem) and + isinstance(filesystem.base_fs, LocalFileSystem)) + ) + + # allow normalizing irregular paths such as Windows local paths + paths = [filesystem.normalize_path(_stringify_path(p)) for p in paths] + + # validate that all of the paths are pointing to existing *files* + # possible improvement is to group the file_infos by type and raise for + # multiple paths per error category + if is_local: + for info in filesystem.get_file_info(paths): + file_type = info.type + if file_type == FileType.File: + continue + elif file_type == FileType.NotFound: + raise FileNotFoundError(info.path) + elif file_type == FileType.Directory: + raise IsADirectoryError( + 'Path {} points to a directory, but only file paths are ' + 'supported. To construct a nested or union dataset pass ' + 'a list of dataset objects instead.'.format(info.path) + ) + else: + raise IOError( + 'Path {} exists but its type is unknown (could be a ' + 'special file such as a Unix socket or character device, ' + 'or Windows NUL / CON / ...)'.format(info.path) + ) + + return filesystem, paths + + +def _ensure_single_source(path, filesystem=None): + """ + Treat path as either a recursively traversable directory or a single file. 
+ + Parameters + ---------- + path : path-like + filesystem : FileSystem or str, optional + If an URI is passed, then its path component will act as a prefix for + the file paths. + + Returns + ------- + (FileSystem, list of str or fs.Selector) + File system object and either a single item list pointing to a file or + an fs.Selector object pointing to a directory. + + Raises + ------ + TypeError + If the passed filesystem has wrong type. + FileNotFoundError + If the referenced file or directory doesn't exist. + """ + from pyarrow.fs import FileType, FileSelector, _resolve_filesystem_and_path + + # at this point we already checked that `path` is a path-like + filesystem, path = _resolve_filesystem_and_path(path, filesystem) + + # ensure that the path is normalized before passing to dataset discovery + path = filesystem.normalize_path(path) + + # retrieve the file descriptor + file_info = filesystem.get_file_info(path) + + # depending on the path type either return with a recursive + # directory selector or as a list containing a single file + if file_info.type == FileType.Directory: + paths_or_selector = FileSelector(path, recursive=True) + elif file_info.type == FileType.File: + paths_or_selector = [path] + else: + raise FileNotFoundError(path) + + return filesystem, paths_or_selector + + +def _filesystem_dataset(source, schema=None, filesystem=None, + partitioning=None, format=None, + partition_base_dir=None, exclude_invalid_files=None, + selector_ignore_prefixes=None): + """ + Create a FileSystemDataset which can be used to build a Dataset. + + Parameters are documented in the dataset function. + + Returns + ------- + FileSystemDataset + """ + from pyarrow.fs import LocalFileSystem, _ensure_filesystem, FileInfo + + format = _ensure_format(format or 'parquet') + partitioning = _ensure_partitioning(partitioning) + + if isinstance(source, (list, tuple)): + if source and isinstance(source[0], FileInfo): + if filesystem is None: + # fall back to local file system as the default + fs = LocalFileSystem() + else: + # construct a filesystem if it is a valid URI + fs = _ensure_filesystem(filesystem) + paths_or_selector = source + else: + fs, paths_or_selector = _ensure_multiple_sources(source, filesystem) + else: + fs, paths_or_selector = _ensure_single_source(source, filesystem) + + options = FileSystemFactoryOptions( + partitioning=partitioning, + partition_base_dir=partition_base_dir, + exclude_invalid_files=exclude_invalid_files, + selector_ignore_prefixes=selector_ignore_prefixes + ) + factory = FileSystemDatasetFactory(fs, paths_or_selector, format, options) + + return factory.finish(schema) + + +def _in_memory_dataset(source, schema=None, **kwargs): + if any(v is not None for v in kwargs.values()): + raise ValueError( + "For in-memory datasets, you cannot pass any additional arguments") + return InMemoryDataset(source, schema) + + +def _union_dataset(children, schema=None, **kwargs): + if any(v is not None for v in kwargs.values()): + raise ValueError( + "When passing a list of Datasets, you cannot pass any additional " + "arguments" + ) + + if schema is None: + # unify the children datasets' schemas + schema = pa.unify_schemas([child.schema for child in children]) + + for child in children: + if getattr(child, "_scan_options", None): + raise ValueError( + "Creating an UnionDataset from filtered or projected Datasets " + "is currently not supported. Union the unfiltered datasets " + "and apply the filter to the resulting union." 
+ ) + + # create datasets with the requested schema + children = [child.replace_schema(schema) for child in children] + + return UnionDataset(schema, children) + + +def parquet_dataset(metadata_path, schema=None, filesystem=None, format=None, + partitioning=None, partition_base_dir=None): + """ + Create a FileSystemDataset from a `_metadata` file created via + `pyarrow.parquet.write_metadata`. + + Parameters + ---------- + metadata_path : path, + Path pointing to a single file parquet metadata file + schema : Schema, optional + Optionally provide the Schema for the Dataset, in which case it will + not be inferred from the source. + filesystem : FileSystem or URI string, default None + If a single path is given as source and filesystem is None, then the + filesystem will be inferred from the path. + If an URI string is passed, then a filesystem object is constructed + using the URI's optional path component as a directory prefix. See the + examples below. + Note that the URIs on Windows must follow 'file:///C:...' or + 'file:/C:...' patterns. + format : ParquetFileFormat + An instance of a ParquetFileFormat if special options needs to be + passed. + partitioning : Partitioning, PartitioningFactory, str, list of str + The partitioning scheme specified with the ``partitioning()`` + function. A flavor string can be used as shortcut, and with a list of + field names a DirectoryPartitioning will be inferred. + partition_base_dir : str, optional + For the purposes of applying the partitioning, paths will be + stripped of the partition_base_dir. Files not matching the + partition_base_dir prefix will be skipped for partitioning discovery. + The ignored files will still be part of the Dataset, but will not + have partition information. + + Returns + ------- + FileSystemDataset + The dataset corresponding to the given metadata + """ + from pyarrow.fs import LocalFileSystem, _ensure_filesystem + + if format is None: + format = ParquetFileFormat() + elif not isinstance(format, ParquetFileFormat): + raise ValueError("format argument must be a ParquetFileFormat") + + if filesystem is None: + filesystem = LocalFileSystem() + else: + filesystem = _ensure_filesystem(filesystem) + + metadata_path = filesystem.normalize_path(_stringify_path(metadata_path)) + options = ParquetFactoryOptions( + partition_base_dir=partition_base_dir, + partitioning=_ensure_partitioning(partitioning) + ) + + factory = ParquetDatasetFactory( + metadata_path, filesystem, format, options=options) + return factory.finish(schema) + + +def dataset(source, schema=None, format=None, filesystem=None, + partitioning=None, partition_base_dir=None, + exclude_invalid_files=None, ignore_prefixes=None): + """ + Open a dataset. + + Datasets provides functionality to efficiently work with tabular, + potentially larger than memory and multi-file dataset. + + - A unified interface for different sources, like Parquet and Feather + - Discovery of sources (crawling directories, handle directory-based + partitioned datasets, basic schema normalization) + - Optimized reading with predicate pushdown (filtering rows), projection + (selecting columns), parallel reading or fine-grained managing of tasks. + + Note that this is the high-level API, to have more control over the dataset + construction use the low-level API classes (FileSystemDataset, + FilesystemDatasetFactory, etc.) 
+ + Parameters + ---------- + source : path, list of paths, dataset, list of datasets, (list of) \ +RecordBatch or Table, iterable of RecordBatch, RecordBatchReader, or URI + Path pointing to a single file: + Open a FileSystemDataset from a single file. + Path pointing to a directory: + The directory gets discovered recursively according to a + partitioning scheme if given. + List of file paths: + Create a FileSystemDataset from explicitly given files. The files + must be located on the same filesystem given by the filesystem + parameter. + Note that in contrary of construction from a single file, passing + URIs as paths is not allowed. + List of datasets: + A nested UnionDataset gets constructed, it allows arbitrary + composition of other datasets. + Note that additional keyword arguments are not allowed. + (List of) batches or tables, iterable of batches, or RecordBatchReader: + Create an InMemoryDataset. If an iterable or empty list is given, + a schema must also be given. If an iterable or RecordBatchReader + is given, the resulting dataset can only be scanned once; further + attempts will raise an error. + schema : Schema, optional + Optionally provide the Schema for the Dataset, in which case it will + not be inferred from the source. + format : FileFormat or str + Currently "parquet", "ipc"/"arrow"/"feather", "csv", "json", and "orc" are + supported. For Feather, only version 2 files are supported. + filesystem : FileSystem or URI string, default None + If a single path is given as source and filesystem is None, then the + filesystem will be inferred from the path. + If an URI string is passed, then a filesystem object is constructed + using the URI's optional path component as a directory prefix. See the + examples below. + Note that the URIs on Windows must follow 'file:///C:...' or + 'file:/C:...' patterns. + partitioning : Partitioning, PartitioningFactory, str, list of str + The partitioning scheme specified with the ``partitioning()`` + function. A flavor string can be used as shortcut, and with a list of + field names a DirectoryPartitioning will be inferred. + partition_base_dir : str, optional + For the purposes of applying the partitioning, paths will be + stripped of the partition_base_dir. Files not matching the + partition_base_dir prefix will be skipped for partitioning discovery. + The ignored files will still be part of the Dataset, but will not + have partition information. + exclude_invalid_files : bool, optional (default True) + If True, invalid files will be excluded (file format specific check). + This will incur IO for each files in a serial and single threaded + fashion. Disabling this feature will skip the IO, but unsupported + files may be present in the Dataset (resulting in an error at scan + time). + ignore_prefixes : list, optional + Files matching any of these prefixes will be ignored by the + discovery process. This is matched to the basename of a path. + By default this is ['.', '_']. + Note that discovery happens only if a directory is passed as source. + + Returns + ------- + dataset : Dataset + Either a FileSystemDataset or a UnionDataset depending on the source + parameter. + + Examples + -------- + Creating an example Table: + + >>> import pyarrow as pa + >>> import pyarrow.parquet as pq + >>> table = pa.table({'year': [2020, 2022, 2021, 2022, 2019, 2021], + ... 'n_legs': [2, 2, 4, 4, 5, 100], + ... 'animal': ["Flamingo", "Parrot", "Dog", "Horse", + ... 
"Brittle stars", "Centipede"]}) + >>> pq.write_table(table, "file.parquet") + + Opening a single file: + + >>> import pyarrow.dataset as ds + >>> dataset = ds.dataset("file.parquet", format="parquet") + >>> dataset.to_table() + pyarrow.Table + year: int64 + n_legs: int64 + animal: string + ---- + year: [[2020,2022,2021,2022,2019,2021]] + n_legs: [[2,2,4,4,5,100]] + animal: [["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"]] + + Opening a single file with an explicit schema: + + >>> myschema = pa.schema([ + ... ('n_legs', pa.int64()), + ... ('animal', pa.string())]) + >>> dataset = ds.dataset("file.parquet", schema=myschema, format="parquet") + >>> dataset.to_table() + pyarrow.Table + n_legs: int64 + animal: string + ---- + n_legs: [[2,2,4,4,5,100]] + animal: [["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"]] + + Opening a dataset for a single directory: + + >>> ds.write_dataset(table, "partitioned_dataset", format="parquet", + ... partitioning=['year']) + >>> dataset = ds.dataset("partitioned_dataset", format="parquet") + >>> dataset.to_table() + pyarrow.Table + n_legs: int64 + animal: string + ---- + n_legs: [[5],[2],[4,100],[2,4]] + animal: [["Brittle stars"],["Flamingo"],...["Parrot","Horse"]] + + For a single directory from a S3 bucket: + + >>> ds.dataset("s3://mybucket/nyc-taxi/", + ... format="parquet") # doctest: +SKIP + + Opening a dataset from a list of relatives local paths: + + >>> dataset = ds.dataset([ + ... "partitioned_dataset/2019/part-0.parquet", + ... "partitioned_dataset/2020/part-0.parquet", + ... "partitioned_dataset/2021/part-0.parquet", + ... ], format='parquet') + >>> dataset.to_table() + pyarrow.Table + n_legs: int64 + animal: string + ---- + n_legs: [[5],[2],[4,100]] + animal: [["Brittle stars"],["Flamingo"],["Dog","Centipede"]] + + With filesystem provided: + + >>> paths = [ + ... 'part0/data.parquet', + ... 'part1/data.parquet', + ... 'part3/data.parquet', + ... ] + >>> ds.dataset(paths, filesystem='file:///directory/prefix, + ... format='parquet') # doctest: +SKIP + + Which is equivalent with: + + >>> fs = SubTreeFileSystem("/directory/prefix", + ... LocalFileSystem()) # doctest: +SKIP + >>> ds.dataset(paths, filesystem=fs, format='parquet') # doctest: +SKIP + + With a remote filesystem URI: + + >>> paths = [ + ... 'nested/directory/part0/data.parquet', + ... 'nested/directory/part1/data.parquet', + ... 'nested/directory/part3/data.parquet', + ... ] + >>> ds.dataset(paths, filesystem='s3://bucket/', + ... format='parquet') # doctest: +SKIP + + Similarly to the local example, the directory prefix may be included in the + filesystem URI: + + >>> ds.dataset(paths, filesystem='s3://bucket/nested/directory', + ... format='parquet') # doctest: +SKIP + + Construction of a nested dataset: + + >>> ds.dataset([ + ... dataset("s3://old-taxi-data", format="parquet"), + ... dataset("local/path/to/data", format="ipc") + ... 
]) # doctest: +SKIP + """ + from pyarrow.fs import FileInfo + # collect the keyword arguments for later reuse + kwargs = dict( + schema=schema, + filesystem=filesystem, + partitioning=partitioning, + format=format, + partition_base_dir=partition_base_dir, + exclude_invalid_files=exclude_invalid_files, + selector_ignore_prefixes=ignore_prefixes + ) + + if _is_path_like(source): + return _filesystem_dataset(source, **kwargs) + elif isinstance(source, (tuple, list)): + if all(_is_path_like(elem) or isinstance(elem, FileInfo) for elem in source): + return _filesystem_dataset(source, **kwargs) + elif all(isinstance(elem, Dataset) for elem in source): + return _union_dataset(source, **kwargs) + elif all(isinstance(elem, (pa.RecordBatch, pa.Table)) + for elem in source): + return _in_memory_dataset(source, **kwargs) + else: + unique_types = set(type(elem).__name__ for elem in source) + type_names = ', '.join('{}'.format(t) for t in unique_types) + raise TypeError( + 'Expected a list of path-like or dataset objects, or a list ' + 'of batches or tables. The given list contains the following ' + 'types: {}'.format(type_names) + ) + elif isinstance(source, (pa.RecordBatch, pa.Table)): + return _in_memory_dataset(source, **kwargs) + else: + raise TypeError( + 'Expected a path-like, list of path-likes or a list of Datasets ' + 'instead of the given type: {}'.format(type(source).__name__) + ) + + +def _ensure_write_partitioning(part, schema, flavor): + if isinstance(part, PartitioningFactory): + raise ValueError("A PartitioningFactory cannot be used. " + "Did you call the partitioning function " + "without supplying a schema?") + + if isinstance(part, Partitioning) and flavor: + raise ValueError( + "Providing a partitioning_flavor with " + "a Partitioning object is not supported" + ) + elif isinstance(part, (tuple, list)): + # Name of fields were provided instead of a partitioning object. + # Create a partitioning factory with those field names. + part = partitioning( + schema=pa.schema([schema.field(f) for f in part]), + flavor=flavor + ) + elif part is None: + part = partitioning(pa.schema([]), flavor=flavor) + + if not isinstance(part, Partitioning): + raise ValueError( + "partitioning must be a Partitioning object or " + "a list of column names" + ) + + return part + + +def write_dataset(data, base_dir, *, basename_template=None, format=None, + partitioning=None, partitioning_flavor=None, schema=None, + filesystem=None, file_options=None, use_threads=True, + max_partitions=None, max_open_files=None, + max_rows_per_file=None, min_rows_per_group=None, + max_rows_per_group=None, file_visitor=None, + existing_data_behavior='error', create_dir=True): + """ + Write a dataset to a given format and partitioning. + + Parameters + ---------- + data : Dataset, Table/RecordBatch, RecordBatchReader, list of \ +Table/RecordBatch, or iterable of RecordBatch + The data to write. This can be a Dataset instance or + in-memory Arrow data. If an iterable is given, the schema must + also be given. + base_dir : str + The root directory where to write the dataset. + basename_template : str, optional + A template string used to generate basenames of written data files. + The token '{i}' will be replaced with an automatically incremented + integer. If not specified, it defaults to + "part-{i}." + format.default_extname + format : FileFormat or str + The format in which to write the dataset. Currently supported: + "parquet", "ipc"/"arrow"/"feather", and "csv". 
If a FileSystemDataset + is being written and `format` is not specified, it defaults to the + same format as the specified FileSystemDataset. When writing a + Table or RecordBatch, this keyword is required. + partitioning : Partitioning or list[str], optional + The partitioning scheme specified with the ``partitioning()`` + function or a list of field names. When providing a list of + field names, you can use ``partitioning_flavor`` to drive which + partitioning type should be used. + partitioning_flavor : str, optional + One of the partitioning flavors supported by + ``pyarrow.dataset.partitioning``. If omitted will use the + default of ``partitioning()`` which is directory partitioning. + schema : Schema, optional + filesystem : FileSystem, optional + file_options : pyarrow.dataset.FileWriteOptions, optional + FileFormat specific write options, created using the + ``FileFormat.make_write_options()`` function. + use_threads : bool, default True + Write files in parallel. If enabled, then maximum parallelism will be + used determined by the number of available CPU cores. + max_partitions : int, default 1024 + Maximum number of partitions any batch may be written into. + max_open_files : int, default 1024 + If greater than 0 then this will limit the maximum number of + files that can be left open. If an attempt is made to open + too many files then the least recently used file will be closed. + If this setting is set too low you may end up fragmenting your + data into many small files. + max_rows_per_file : int, default 0 + Maximum number of rows per file. If greater than 0 then this will + limit how many rows are placed in any single file. Otherwise there + will be no limit and one file will be created in each output + directory unless files need to be closed to respect max_open_files + min_rows_per_group : int, default 0 + Minimum number of rows per group. When the value is greater than 0, + the dataset writer will batch incoming data and only write the row + groups to the disk when sufficient rows have accumulated. + max_rows_per_group : int, default 1024 * 1024 + Maximum number of rows per group. If the value is greater than 0, + then the dataset writer may split up large incoming batches into + multiple row groups. If this value is set, then min_rows_per_group + should also be set. Otherwise it could end up with very small row + groups. + file_visitor : function + If set, this function will be called with a WrittenFile instance + for each file created during the call. This object will have both + a path attribute and a metadata attribute. + + The path attribute will be a string containing the path to + the created file. + + The metadata attribute will be the parquet metadata of the file. + This metadata will have the file path attribute set and can be used + to build a _metadata file. The metadata attribute will be None if + the format is not parquet. + + Example visitor which simple collects the filenames created:: + + visited_paths = [] + + def file_visitor(written_file): + visited_paths.append(written_file.path) + existing_data_behavior : 'error' | 'overwrite_or_ignore' | \ +'delete_matching' + Controls how the dataset will handle data that already exists in + the destination. The default behavior ('error') is to raise an error + if any data exists in the destination. + + 'overwrite_or_ignore' will ignore any existing data and will + overwrite files with the same name as an output file. Other + existing files will be ignored. 
This behavior, in combination + with a unique basename_template for each write, will allow for + an append workflow. + + 'delete_matching' is useful when you are writing a partitioned + dataset. The first time each partition directory is encountered + the entire directory will be deleted. This allows you to overwrite + old partitions completely. + create_dir : bool, default True + If False, directories will not be created. This can be useful for + filesystems that do not require directories. + """ + from pyarrow.fs import _resolve_filesystem_and_path + + if isinstance(data, (list, tuple)): + schema = schema or data[0].schema + data = InMemoryDataset(data, schema=schema) + elif isinstance(data, (pa.RecordBatch, pa.Table)): + schema = schema or data.schema + data = InMemoryDataset(data, schema=schema) + elif ( + isinstance(data, pa.ipc.RecordBatchReader) + or hasattr(data, "__arrow_c_stream__") + or _is_iterable(data) + ): + data = Scanner.from_batches(data, schema=schema) + schema = None + elif not isinstance(data, (Dataset, Scanner)): + raise ValueError( + "Only Dataset, Scanner, Table/RecordBatch, RecordBatchReader, " + "a list of Tables/RecordBatches, or iterable of batches are " + "supported." + ) + + if format is None and isinstance(data, FileSystemDataset): + format = data.format + else: + format = _ensure_format(format) + + if file_options is None: + file_options = format.make_write_options() + + if format != file_options.format: + raise TypeError("Supplied FileWriteOptions have format {}, " + "which doesn't match supplied FileFormat {}".format( + format, file_options)) + + if basename_template is None: + basename_template = "part-{i}." + format.default_extname + + if max_partitions is None: + max_partitions = 1024 + + if max_open_files is None: + max_open_files = 1024 + + if max_rows_per_file is None: + max_rows_per_file = 0 + + if max_rows_per_group is None: + max_rows_per_group = 1 << 20 + + if min_rows_per_group is None: + min_rows_per_group = 0 + + # at this point data is a Scanner or a Dataset, anything else + # was converted to one of those two. So we can grab the schema + # to build the partitioning object from Dataset. + if isinstance(data, Scanner): + partitioning_schema = data.projected_schema + else: + partitioning_schema = data.schema + partitioning = _ensure_write_partitioning(partitioning, + schema=partitioning_schema, + flavor=partitioning_flavor) + + filesystem, base_dir = _resolve_filesystem_and_path(base_dir, filesystem) + + if isinstance(data, Dataset): + scanner = data.scanner(use_threads=use_threads) + else: + # scanner was passed directly by the user, in which case a schema + # cannot be passed + if schema is not None: + raise ValueError("Cannot specify a schema when writing a Scanner") + scanner = data + + _filesystemdataset_write( + scanner, base_dir, basename_template, filesystem, partitioning, + file_options, max_partitions, file_visitor, existing_data_behavior, + max_open_files, max_rows_per_file, + min_rows_per_group, max_rows_per_group, create_dir + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/device.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/device.pxi new file mode 100644 index 0000000000000000000000000000000000000000..26256de62093e84075a5bfc3eba9a95d12db6195 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/device.pxi @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: profile=False
+# distutils: language = c++
+# cython: embedsignature = True
+
+
+cpdef enum DeviceAllocationType:
+    CPU = CDeviceAllocationType_kCPU
+    CUDA = CDeviceAllocationType_kCUDA
+    CUDA_HOST = CDeviceAllocationType_kCUDA_HOST
+    OPENCL = CDeviceAllocationType_kOPENCL
+    VULKAN = CDeviceAllocationType_kVULKAN
+    METAL = CDeviceAllocationType_kMETAL
+    VPI = CDeviceAllocationType_kVPI
+    ROCM = CDeviceAllocationType_kROCM
+    ROCM_HOST = CDeviceAllocationType_kROCM_HOST
+    EXT_DEV = CDeviceAllocationType_kEXT_DEV
+    CUDA_MANAGED = CDeviceAllocationType_kCUDA_MANAGED
+    ONEAPI = CDeviceAllocationType_kONEAPI
+    WEBGPU = CDeviceAllocationType_kWEBGPU
+    HEXAGON = CDeviceAllocationType_kHEXAGON
+
+
+cdef object _wrap_device_allocation_type(CDeviceAllocationType device_type):
+    return DeviceAllocationType(<int> device_type)
+
+
+cdef class Device(_Weakrefable):
+    """
+    Abstract interface for hardware devices
+
+    This object represents a device with access to some memory spaces.
+    When handling a Buffer or raw memory address, it allows deciding in which
+    context the raw memory address should be interpreted
+    (e.g. CPU-accessible memory, or embedded memory on some particular GPU).
+    """
+
+    def __init__(self):
+        raise TypeError("Do not call Device's constructor directly, "
+                        "use the device attribute of the MemoryManager instead.")
+
+    cdef void init(self, const shared_ptr[CDevice]& device):
+        self.device = device
+
+    @staticmethod
+    cdef wrap(const shared_ptr[CDevice]& device):
+        cdef Device self = Device.__new__(Device)
+        self.init(device)
+        return self
+
+    cdef inline shared_ptr[CDevice] unwrap(self) nogil:
+        return self.device
+
+    def __eq__(self, other):
+        if not isinstance(other, Device):
+            return False
+        return self.device.get().Equals(deref((<Device>other).device.get()))
+
+    def __repr__(self):
+        return "<pyarrow.Device: {}>".format(frombytes(self.device.get().ToString()))
+
+    @property
+    def type_name(self):
+        """
+        A shorthand for this device's type.
+        """
+        return frombytes(self.device.get().type_name())
+
+    @property
+    def device_id(self):
+        """
+        A device ID to identify this device if there are multiple of this type.
+
+        If there is no "device_id" equivalent (such as for the main CPU device on
+        non-numa systems) returns -1.
+        """
+        return self.device.get().device_id()
+
+    @property
+    def is_cpu(self):
+        """
+        Whether this device is the main CPU device.
+
+        This shorthand method is very useful when deciding whether a memory address
+        is CPU-accessible.
+        """
+        return self.device.get().is_cpu()
+
+    @property
+    def device_type(self):
+        """
+        Return the DeviceAllocationType of this device.
+        """
+        return _wrap_device_allocation_type(self.device.get().device_type())
+
+
+cdef class MemoryManager(_Weakrefable):
+    """
+    An object that provides memory management primitives.
+
+    A MemoryManager is always tied to a particular Device instance.
+    It can also have additional parameters (such as a MemoryPool to
+    allocate CPU memory).
+
+    """
+
+    def __init__(self):
+        raise TypeError("Do not call MemoryManager's constructor directly, "
+                        "use pyarrow.default_cpu_memory_manager() instead.")
+
+    cdef void init(self, const shared_ptr[CMemoryManager]& mm):
+        self.memory_manager = mm
+
+    @staticmethod
+    cdef wrap(const shared_ptr[CMemoryManager]& mm):
+        cdef MemoryManager self = MemoryManager.__new__(MemoryManager)
+        self.init(mm)
+        return self
+
+    cdef inline shared_ptr[CMemoryManager] unwrap(self) nogil:
+        return self.memory_manager
+
+    def __repr__(self):
+        return "<pyarrow.MemoryManager device: {}>".format(
+            frombytes(self.memory_manager.get().device().get().ToString())
+        )
+
+    @property
+    def device(self):
+        """
+        The device this MemoryManager is tied to.
+        """
+        return Device.wrap(self.memory_manager.get().device())
+
+    @property
+    def is_cpu(self):
+        """
+        Whether this MemoryManager is tied to the main CPU device.
+
+        This shorthand method is very useful when deciding whether a memory
+        address is CPU-accessible.
+        """
+        return self.memory_manager.get().is_cpu()
+
+
+def default_cpu_memory_manager():
+    """
+    Return the default CPU MemoryManager instance.
+
+    The returned singleton instance uses the default MemoryPool.
+    """
+    return MemoryManager.wrap(c_default_cpu_memory_manager())
diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/feather.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/feather.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbd0602597006734d66a9a965ea462fb35cbe178
--- /dev/null
+++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/feather.py
@@ -0,0 +1,277 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import os
+
+from pyarrow.pandas_compat import _pandas_api  # noqa
+from pyarrow.lib import (Codec, Table,  # noqa
+                         concat_tables, schema)
+import pyarrow.lib as ext
+from pyarrow import _feather
+from pyarrow._feather import FeatherError  # noqa: F401
+
+
+class FeatherDataset:
+    """
+    Encapsulates details of reading a list of Feather files.
+ + Parameters + ---------- + path_or_paths : List[str] + A list of file names + validate_schema : bool, default True + Check that individual file schemas are all the same / compatible + """ + + def __init__(self, path_or_paths, validate_schema=True): + self.paths = path_or_paths + self.validate_schema = validate_schema + + def read_table(self, columns=None): + """ + Read multiple feather files as a single pyarrow.Table + + Parameters + ---------- + columns : List[str] + Names of columns to read from the file + + Returns + ------- + pyarrow.Table + Content of the file as a table (of columns) + """ + _fil = read_table(self.paths[0], columns=columns) + self._tables = [_fil] + self.schema = _fil.schema + + for path in self.paths[1:]: + table = read_table(path, columns=columns) + if self.validate_schema: + self.validate_schemas(path, table) + self._tables.append(table) + return concat_tables(self._tables) + + def validate_schemas(self, piece, table): + if not self.schema.equals(table.schema): + raise ValueError('Schema in {!s} was different. \n' + '{!s}\n\nvs\n\n{!s}' + .format(piece, self.schema, + table.schema)) + + def read_pandas(self, columns=None, use_threads=True): + """ + Read multiple Parquet files as a single pandas DataFrame + + Parameters + ---------- + columns : List[str] + Names of columns to read from the file + use_threads : bool, default True + Use multiple threads when converting to pandas + + Returns + ------- + pandas.DataFrame + Content of the file as a pandas DataFrame (of columns) + """ + return self.read_table(columns=columns).to_pandas( + use_threads=use_threads) + + +def check_chunked_overflow(name, col): + if col.num_chunks == 1: + return + + if col.type in (ext.binary(), ext.string()): + raise ValueError("Column '{}' exceeds 2GB maximum capacity of " + "a Feather binary column. This restriction may be " + "lifted in the future".format(name)) + else: + # TODO(wesm): Not sure when else this might be reached + raise ValueError("Column '{}' of type {} was chunked on conversion " + "to Arrow and cannot be currently written to " + "Feather format".format(name, str(col.type))) + + +_FEATHER_SUPPORTED_CODECS = {'lz4', 'zstd', 'uncompressed'} + + +def write_feather(df, dest, compression=None, compression_level=None, + chunksize=None, version=2): + """ + Write a pandas.DataFrame to Feather format. + + Parameters + ---------- + df : pandas.DataFrame or pyarrow.Table + Data to write out as Feather format. + dest : str + Local destination path. + compression : string, default None + Can be one of {"zstd", "lz4", "uncompressed"}. The default of None uses + LZ4 for V2 files if it is available, otherwise uncompressed. + compression_level : int, default None + Use a compression level particular to the chosen compressor. If None + use the default compression level + chunksize : int, default None + For V2 files, the internal maximum size of Arrow RecordBatch chunks + when writing the Arrow IPC file format. None means use the default, + which is currently 64K + version : int, default 2 + Feather file version. Version 2 is the current. 
Version 1 is the more + limited legacy format + """ + if _pandas_api.have_pandas: + if (_pandas_api.has_sparse and + isinstance(df, _pandas_api.pd.SparseDataFrame)): + df = df.to_dense() + + if _pandas_api.is_data_frame(df): + # Feather v1 creates a new column in the resultant Table to + # store index information if index type is not RangeIndex + + if version == 1: + preserve_index = False + elif version == 2: + preserve_index = None + else: + raise ValueError("Version value should either be 1 or 2") + + table = Table.from_pandas(df, preserve_index=preserve_index) + + if version == 1: + # Version 1 does not chunking + for i, name in enumerate(table.schema.names): + col = table[i] + check_chunked_overflow(name, col) + else: + table = df + + if version == 1: + if len(table.column_names) > len(set(table.column_names)): + raise ValueError("cannot serialize duplicate column names") + + if compression is not None: + raise ValueError("Feather V1 files do not support compression " + "option") + + if chunksize is not None: + raise ValueError("Feather V1 files do not support chunksize " + "option") + else: + if compression is None and Codec.is_available('lz4_frame'): + compression = 'lz4' + elif (compression is not None and + compression not in _FEATHER_SUPPORTED_CODECS): + raise ValueError('compression="{}" not supported, must be ' + 'one of {}'.format(compression, + _FEATHER_SUPPORTED_CODECS)) + + try: + _feather.write_feather(table, dest, compression=compression, + compression_level=compression_level, + chunksize=chunksize, version=version) + except Exception: + if isinstance(dest, str): + try: + os.remove(dest) + except os.error: + pass + raise + + +def read_feather(source, columns=None, use_threads=True, + memory_map=False, **kwargs): + """ + Read a pandas.DataFrame from Feather format. To read as pyarrow.Table use + feather.read_table. + + Parameters + ---------- + source : str file path, or file-like object + You can use MemoryMappedFile as source, for explicitly use memory map. + columns : sequence, optional + Only read a specific set of columns. If not provided, all columns are + read. + use_threads : bool, default True + Whether to parallelize reading using multiple threads. If false the + restriction is used in the conversion to Pandas as well as in the + reading from Feather format. + memory_map : boolean, default False + Use memory mapping when opening file on disk, when source is a str. + **kwargs + Additional keyword arguments passed on to `pyarrow.Table.to_pandas`. + + Returns + ------- + df : pandas.DataFrame + The contents of the Feather file as a pandas.DataFrame + """ + return (read_table( + source, columns=columns, memory_map=memory_map, + use_threads=use_threads).to_pandas(use_threads=use_threads, **kwargs)) + + +def read_table(source, columns=None, memory_map=False, use_threads=True): + """ + Read a pyarrow.Table from Feather format + + Parameters + ---------- + source : str file path, or file-like object + You can use MemoryMappedFile as source, for explicitly use memory map. + columns : sequence, optional + Only read a specific set of columns. If not provided, all columns are + read. + memory_map : boolean, default False + Use memory mapping when opening file on disk, when source is a str + use_threads : bool, default True + Whether to parallelize reading using multiple threads. 
+ + Returns + ------- + table : pyarrow.Table + The contents of the Feather file as a pyarrow.Table + """ + reader = _feather.FeatherReader( + source, use_memory_map=memory_map, use_threads=use_threads) + + if columns is None: + return reader.read() + + column_types = [type(column) for column in columns] + if all(map(lambda t: t == int, column_types)): + table = reader.read_indices(columns) + elif all(map(lambda t: t == str, column_types)): + table = reader.read_names(columns) + else: + column_type_names = [t.__name__ for t in column_types] + raise TypeError("Columns must be indices or names. " + "Got columns {} of types {}" + .format(columns, column_type_names)) + + # Feather v1 already respects the column selection + if reader.version < 3: + return table + # Feather v2 reads with sorted / deduplicated selection + elif sorted(set(columns)) == columns: + return table + else: + # follow exact order / selection of names + return table.select(columns) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/flight.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/flight.py new file mode 100644 index 0000000000000000000000000000000000000000..b1836907c6744161c86f32e873316923c60b4226 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/flight.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
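+
+# A minimal client-side sketch, assuming a Flight server is already listening
+# at the given URI; illustrative only, using only names re-exported below
+# (connect, FlightClient.list_flights, FlightClient.do_get):
+#
+#   import pyarrow.flight as flight
+#
+#   client = flight.connect("grpc://localhost:8815")
+#   for info in client.list_flights():
+#       reader = client.do_get(info.endpoints[0].ticket)
+#       table = reader.read_all()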
+ +try: + from pyarrow._flight import ( # noqa:F401 + connect, + Action, + ActionType, + BasicAuth, + CallInfo, + CertKeyPair, + ClientAuthHandler, + ClientMiddleware, + ClientMiddlewareFactory, + DescriptorType, + FlightCallOptions, + FlightCancelledError, + FlightClient, + FlightDataStream, + FlightDescriptor, + FlightEndpoint, + FlightError, + FlightInfo, + FlightInternalError, + FlightMetadataReader, + FlightMetadataWriter, + FlightMethod, + FlightServerBase, + FlightServerError, + FlightStreamChunk, + FlightStreamReader, + FlightStreamWriter, + FlightTimedOutError, + FlightUnauthenticatedError, + FlightUnauthorizedError, + FlightUnavailableError, + FlightWriteSizeExceededError, + GeneratorStream, + Location, + MetadataRecordBatchReader, + MetadataRecordBatchWriter, + RecordBatchStream, + Result, + SchemaResult, + ServerAuthHandler, + ServerCallContext, + ServerMiddleware, + ServerMiddlewareFactory, + Ticket, + TracingServerMiddlewareFactory, + ) +except ImportError as exc: + raise ImportError( + f"The pyarrow installation is not built with support for 'flight' ({str(exc)})" + ) from None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/gandiva.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/gandiva.pyx new file mode 100644 index 0000000000000000000000000000000000000000..2202ec64f29628d76143759220eb61102d1bea97 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/gandiva.pyx @@ -0,0 +1,760 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
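+
+# A minimal usage sketch, assuming this module is imported as pyarrow.gandiva;
+# illustrative only, built from the TreeExprBuilder, make_condition and
+# make_filter APIs defined further down in this file (record_batch is assumed
+# to be a pyarrow.RecordBatch matching the schema):
+#
+#   import pyarrow as pa
+#   import pyarrow.gandiva as gandiva
+#
+#   field = pa.field("x", pa.int64())
+#   schema = pa.schema([field])
+#   builder = gandiva.TreeExprBuilder()
+#   node = builder.make_function(
+#       "greater_than",
+#       [builder.make_field(field), builder.make_literal(10, pa.int64())],
+#       pa.bool_())
+#   fltr = gandiva.make_filter(schema, builder.make_condition(node))
+#   selection = fltr.evaluate(record_batch, pa.default_memory_pool())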
+ +# cython: profile=False +# distutils: language = c++ +# cython: language_level = 3 + +from libcpp.memory cimport shared_ptr +from libcpp.string cimport string as c_string +from libcpp.vector cimport vector as c_vector +from libcpp.unordered_set cimport unordered_set as c_unordered_set +from libc.stdint cimport int64_t, int32_t + +from pyarrow.includes.libarrow cimport * +from pyarrow.lib cimport (DataType, Field, MemoryPool, RecordBatch, + Schema, check_status, pyarrow_wrap_array, + pyarrow_wrap_data_type, ensure_type, _Weakrefable, + pyarrow_wrap_field) + +from pyarrow.includes.libgandiva cimport ( + CCondition, CGandivaExpression, + CNode, CProjector, CFilter, + CSelectionVector, + _ensure_selection_mode, + CConfiguration, + CConfigurationBuilder, + TreeExprBuilder_MakeExpression, + TreeExprBuilder_MakeFunction, + TreeExprBuilder_MakeBoolLiteral, + TreeExprBuilder_MakeUInt8Literal, + TreeExprBuilder_MakeUInt16Literal, + TreeExprBuilder_MakeUInt32Literal, + TreeExprBuilder_MakeUInt64Literal, + TreeExprBuilder_MakeInt8Literal, + TreeExprBuilder_MakeInt16Literal, + TreeExprBuilder_MakeInt32Literal, + TreeExprBuilder_MakeInt64Literal, + TreeExprBuilder_MakeFloatLiteral, + TreeExprBuilder_MakeDoubleLiteral, + TreeExprBuilder_MakeStringLiteral, + TreeExprBuilder_MakeBinaryLiteral, + TreeExprBuilder_MakeField, + TreeExprBuilder_MakeIf, + TreeExprBuilder_MakeAnd, + TreeExprBuilder_MakeOr, + TreeExprBuilder_MakeCondition, + TreeExprBuilder_MakeInExpressionInt32, + TreeExprBuilder_MakeInExpressionInt64, + TreeExprBuilder_MakeInExpressionTime32, + TreeExprBuilder_MakeInExpressionTime64, + TreeExprBuilder_MakeInExpressionDate32, + TreeExprBuilder_MakeInExpressionDate64, + TreeExprBuilder_MakeInExpressionTimeStamp, + TreeExprBuilder_MakeInExpressionString, + SelectionVector_MakeInt16, + SelectionVector_MakeInt32, + SelectionVector_MakeInt64, + Projector_Make, + Filter_Make, + CFunctionSignature, + GetRegisteredFunctionSignatures) + + +cdef class Node(_Weakrefable): + cdef: + shared_ptr[CNode] node + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use the " + "TreeExprBuilder API directly" + .format(self.__class__.__name__)) + + @staticmethod + cdef create(shared_ptr[CNode] node): + cdef Node self = Node.__new__(Node) + self.node = node + return self + + def __str__(self): + return self.node.get().ToString().decode() + + def __repr__(self): + type_format = object.__repr__(self) + return '{0}\n{1}'.format(type_format, str(self)) + + def return_type(self): + return pyarrow_wrap_data_type(self.node.get().return_type()) + + +cdef class Expression(_Weakrefable): + cdef: + shared_ptr[CGandivaExpression] expression + + cdef void init(self, shared_ptr[CGandivaExpression] expression): + self.expression = expression + + def __str__(self): + return self.expression.get().ToString().decode() + + def __repr__(self): + type_format = object.__repr__(self) + return '{0}\n{1}'.format(type_format, str(self)) + + def root(self): + return Node.create(self.expression.get().root()) + + def result(self): + return pyarrow_wrap_field(self.expression.get().result()) + + +cdef class Condition(_Weakrefable): + cdef: + shared_ptr[CCondition] condition + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use the " + "TreeExprBuilder API instead" + .format(self.__class__.__name__)) + + @staticmethod + cdef create(shared_ptr[CCondition] condition): + cdef Condition self = Condition.__new__(Condition) + self.condition = condition + return self + + def __str__(self): + 
return self.condition.get().ToString().decode() + + def __repr__(self): + type_format = object.__repr__(self) + return '{0}\n{1}'.format(type_format, str(self)) + + def root(self): + return Node.create(self.condition.get().root()) + + def result(self): + return pyarrow_wrap_field(self.condition.get().result()) + + +cdef class SelectionVector(_Weakrefable): + cdef: + shared_ptr[CSelectionVector] selection_vector + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly." + .format(self.__class__.__name__)) + + @staticmethod + cdef create(shared_ptr[CSelectionVector] selection_vector): + cdef SelectionVector self = SelectionVector.__new__(SelectionVector) + self.selection_vector = selection_vector + return self + + def to_array(self): + cdef shared_ptr[CArray] result = self.selection_vector.get().ToArray() + return pyarrow_wrap_array(result) + + +cdef class Projector(_Weakrefable): + cdef: + shared_ptr[CProjector] projector + MemoryPool pool + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use " + "make_projector instead" + .format(self.__class__.__name__)) + + @staticmethod + cdef create(shared_ptr[CProjector] projector, MemoryPool pool): + cdef Projector self = Projector.__new__(Projector) + self.projector = projector + self.pool = pool + return self + + @property + def llvm_ir(self): + return self.projector.get().DumpIR().decode() + + def evaluate(self, RecordBatch batch, SelectionVector selection=None): + """ + Evaluate the specified record batch and return the arrays at the + filtered positions. + + Parameters + ---------- + batch : pyarrow.RecordBatch + selection : pyarrow.gandiva.SelectionVector + + Returns + ------- + list[pyarrow.Array] + """ + cdef vector[shared_ptr[CArray]] results + if selection is None: + check_status(self.projector.get().Evaluate( + batch.sp_batch.get()[0], self.pool.pool, &results)) + else: + check_status( + self.projector.get().Evaluate( + batch.sp_batch.get()[0], selection.selection_vector.get(), + self.pool.pool, &results)) + cdef shared_ptr[CArray] result + arrays = [] + for result in results: + arrays.append(pyarrow_wrap_array(result)) + return arrays + + +cdef class Filter(_Weakrefable): + cdef: + shared_ptr[CFilter] filter + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use " + "make_filter instead" + .format(self.__class__.__name__)) + + @staticmethod + cdef create(shared_ptr[CFilter] filter): + cdef Filter self = Filter.__new__(Filter) + self.filter = filter + return self + + @property + def llvm_ir(self): + return self.filter.get().DumpIR().decode() + + def evaluate(self, RecordBatch batch, MemoryPool pool, dtype='int32'): + """ + Evaluate the specified record batch and return a selection vector. 
+ + Parameters + ---------- + batch : pyarrow.RecordBatch + pool : MemoryPool + dtype : DataType or str, default int32 + + Returns + ------- + pyarrow.gandiva.SelectionVector + """ + cdef: + DataType type = ensure_type(dtype) + shared_ptr[CSelectionVector] selection + + if type.id == _Type_INT16: + check_status(SelectionVector_MakeInt16( + batch.num_rows, pool.pool, &selection)) + elif type.id == _Type_INT32: + check_status(SelectionVector_MakeInt32( + batch.num_rows, pool.pool, &selection)) + elif type.id == _Type_INT64: + check_status(SelectionVector_MakeInt64( + batch.num_rows, pool.pool, &selection)) + else: + raise ValueError("'dtype' of the selection vector should be " + "one of 'int16', 'int32' and 'int64'.") + + check_status(self.filter.get().Evaluate( + batch.sp_batch.get()[0], selection)) + return SelectionVector.create(selection) + + +cdef class TreeExprBuilder(_Weakrefable): + + def make_literal(self, value, dtype): + """ + Create a node on a literal. + + Parameters + ---------- + value : a literal value + dtype : DataType + + Returns + ------- + pyarrow.gandiva.Node + """ + cdef: + DataType type = ensure_type(dtype) + shared_ptr[CNode] r + + if type.id == _Type_BOOL: + r = TreeExprBuilder_MakeBoolLiteral(value) + elif type.id == _Type_UINT8: + r = TreeExprBuilder_MakeUInt8Literal(value) + elif type.id == _Type_UINT16: + r = TreeExprBuilder_MakeUInt16Literal(value) + elif type.id == _Type_UINT32: + r = TreeExprBuilder_MakeUInt32Literal(value) + elif type.id == _Type_UINT64: + r = TreeExprBuilder_MakeUInt64Literal(value) + elif type.id == _Type_INT8: + r = TreeExprBuilder_MakeInt8Literal(value) + elif type.id == _Type_INT16: + r = TreeExprBuilder_MakeInt16Literal(value) + elif type.id == _Type_INT32: + r = TreeExprBuilder_MakeInt32Literal(value) + elif type.id == _Type_INT64: + r = TreeExprBuilder_MakeInt64Literal(value) + elif type.id == _Type_FLOAT: + r = TreeExprBuilder_MakeFloatLiteral(value) + elif type.id == _Type_DOUBLE: + r = TreeExprBuilder_MakeDoubleLiteral(value) + elif type.id == _Type_STRING: + r = TreeExprBuilder_MakeStringLiteral(value.encode('UTF-8')) + elif type.id == _Type_BINARY: + r = TreeExprBuilder_MakeBinaryLiteral(value) + else: + raise TypeError("Didn't recognize dtype " + str(dtype)) + + return Node.create(r) + + def make_expression(self, Node root_node not None, + Field return_field not None): + """ + Create an expression with the specified root_node, + and the result written to result_field. + + Parameters + ---------- + root_node : pyarrow.gandiva.Node + return_field : pyarrow.Field + + Returns + ------- + pyarrow.gandiva.Expression + """ + cdef shared_ptr[CGandivaExpression] r = TreeExprBuilder_MakeExpression( + root_node.node, return_field.sp_field) + cdef Expression expression = Expression() + expression.init(r) + return expression + + def make_function(self, name, children, DataType return_type): + """ + Create a node with a function. + + Parameters + ---------- + name : str + children : pyarrow.gandiva.NodeVector + return_type : DataType + + Returns + ------- + pyarrow.gandiva.Node + """ + cdef c_vector[shared_ptr[CNode]] c_children + cdef Node child + for child in children: + if child is None: + raise TypeError("Child nodes must not be None") + c_children.push_back(child.node) + cdef shared_ptr[CNode] r = TreeExprBuilder_MakeFunction( + name.encode(), c_children, return_type.sp_type) + return Node.create(r) + + def make_field(self, Field field not None): + """ + Create a node with an Arrow field. 
+ + Parameters + ---------- + field : pyarrow.Field + + Returns + ------- + pyarrow.gandiva.Node + """ + cdef shared_ptr[CNode] r = TreeExprBuilder_MakeField(field.sp_field) + return Node.create(r) + + def make_if(self, Node condition not None, Node this_node not None, + Node else_node not None, DataType return_type not None): + """ + Create a node with an if-else expression. + + Parameters + ---------- + condition : pyarrow.gandiva.Node + this_node : pyarrow.gandiva.Node + else_node : pyarrow.gandiva.Node + return_type : DataType + + Returns + ------- + pyarrow.gandiva.Node + """ + cdef shared_ptr[CNode] r = TreeExprBuilder_MakeIf( + condition.node, this_node.node, else_node.node, + return_type.sp_type) + return Node.create(r) + + def make_and(self, children): + """ + Create a Node with a boolean AND expression. + + Parameters + ---------- + children : list[pyarrow.gandiva.Node] + + Returns + ------- + pyarrow.gandiva.Node + """ + cdef c_vector[shared_ptr[CNode]] c_children + cdef Node child + for child in children: + if child is None: + raise TypeError("Child nodes must not be None") + c_children.push_back(child.node) + cdef shared_ptr[CNode] r = TreeExprBuilder_MakeAnd(c_children) + return Node.create(r) + + def make_or(self, children): + """ + Create a Node with a boolean OR expression. + + Parameters + ---------- + children : list[pyarrow.gandiva.Node] + + Returns + ------- + pyarrow.gandiva.Node + """ + cdef c_vector[shared_ptr[CNode]] c_children + cdef Node child + for child in children: + if child is None: + raise TypeError("Child nodes must not be None") + c_children.push_back(child.node) + cdef shared_ptr[CNode] r = TreeExprBuilder_MakeOr(c_children) + return Node.create(r) + + def _make_in_expression_int32(self, Node node not None, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[int32_t] c_values + cdef int32_t v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionInt32(node.node, c_values) + return Node.create(r) + + def _make_in_expression_int64(self, Node node not None, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[int64_t] c_values + cdef int64_t v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionInt64(node.node, c_values) + return Node.create(r) + + def _make_in_expression_time32(self, Node node not None, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[int32_t] c_values + cdef int32_t v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionTime32(node.node, c_values) + return Node.create(r) + + def _make_in_expression_time64(self, Node node not None, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[int64_t] c_values + cdef int64_t v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionTime64(node.node, c_values) + return Node.create(r) + + def _make_in_expression_date32(self, Node node not None, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[int32_t] c_values + cdef int32_t v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionDate32(node.node, c_values) + return Node.create(r) + + def _make_in_expression_date64(self, Node node not None, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[int64_t] c_values + cdef int64_t v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionDate64(node.node, c_values) + return Node.create(r) + + def _make_in_expression_timestamp(self, Node node not None, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[int64_t] 
c_values + cdef int64_t v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionTimeStamp(node.node, c_values) + return Node.create(r) + + def _make_in_expression_binary(self, Node node not None, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[c_string] c_values + cdef c_string v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionString(node.node, c_values) + return Node.create(r) + + def _make_in_expression_string(self, Node node not None, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[c_string] c_values + cdef c_string _v + for v in values: + _v = v.encode('UTF-8') + c_values.insert(_v) + r = TreeExprBuilder_MakeInExpressionString(node.node, c_values) + return Node.create(r) + + def make_in_expression(self, Node node not None, values, dtype): + """ + Create a Node with an IN expression. + + Parameters + ---------- + node : pyarrow.gandiva.Node + values : iterable + dtype : DataType + + Returns + ------- + pyarrow.gandiva.Node + """ + cdef DataType type = ensure_type(dtype) + + if type.id == _Type_INT32: + return self._make_in_expression_int32(node, values) + elif type.id == _Type_INT64: + return self._make_in_expression_int64(node, values) + elif type.id == _Type_TIME32: + return self._make_in_expression_time32(node, values) + elif type.id == _Type_TIME64: + return self._make_in_expression_time64(node, values) + elif type.id == _Type_TIMESTAMP: + return self._make_in_expression_timestamp(node, values) + elif type.id == _Type_DATE32: + return self._make_in_expression_date32(node, values) + elif type.id == _Type_DATE64: + return self._make_in_expression_date64(node, values) + elif type.id == _Type_BINARY: + return self._make_in_expression_binary(node, values) + elif type.id == _Type_STRING: + return self._make_in_expression_string(node, values) + else: + raise TypeError("Data type " + str(dtype) + " not supported.") + + def make_condition(self, Node condition not None): + """ + Create a condition with the specified node. + + Parameters + ---------- + condition : pyarrow.gandiva.Node + + Returns + ------- + pyarrow.gandiva.Condition + """ + cdef shared_ptr[CCondition] r = TreeExprBuilder_MakeCondition( + condition.node) + return Condition.create(r) + +cdef class Configuration(_Weakrefable): + cdef: + shared_ptr[CConfiguration] configuration + + def __cinit__(self, bint optimize=True, bint dump_ir=False): + """ + Initialize the configuration with specified options. + + Parameters + ---------- + optimize : bool, default True + Whether to enable optimizations. + dump_ir : bool, default False + Whether to dump LLVM IR. + """ + self.configuration = CConfigurationBuilder().build() + self.configuration.get().set_optimize(optimize) + self.configuration.get().set_dump_ir(dump_ir) + + @staticmethod + cdef create(shared_ptr[CConfiguration] configuration): + """ + Create a Configuration instance from an existing CConfiguration pointer. + + Parameters + ---------- + configuration : shared_ptr[CConfiguration] + Existing CConfiguration pointer. + + Returns + ------- + Configuration instance + """ + cdef Configuration self = Configuration.__new__(Configuration) + self.configuration = configuration + return self + + +cpdef make_projector(Schema schema, children, MemoryPool pool, + str selection_mode="NONE", + Configuration configuration=None): + """ + Construct a projection using expressions. + + A projector is built for a specific schema and vector of expressions. 
+ Once the projector is built, it can be used to evaluate many row batches. + + Parameters + ---------- + schema : pyarrow.Schema + Schema for the record batches, and the expressions. + children : list[pyarrow.gandiva.Expression] + List of projectable expression objects. + pool : pyarrow.MemoryPool + Memory pool used to allocate output arrays. + selection_mode : str, default "NONE" + Possible values are NONE, UINT16, UINT32, UINT64. + configuration : pyarrow.gandiva.Configuration, default None + Configuration for the projector. + + Returns + ------- + Projector instance + """ + cdef: + Expression child + c_vector[shared_ptr[CGandivaExpression]] c_children + shared_ptr[CProjector] result + + if configuration is None: + configuration = Configuration() + + for child in children: + if child is None: + raise TypeError("Expressions must not be None") + c_children.push_back(child.expression) + + check_status( + Projector_Make(schema.sp_schema, c_children, + _ensure_selection_mode(selection_mode), + configuration.configuration, + &result)) + return Projector.create(result, pool) + + +cpdef make_filter(Schema schema, Condition condition, + Configuration configuration=None): + """ + Construct a filter based on a condition. + + A filter is built for a specific schema and condition. Once the filter is + built, it can be used to evaluate many row batches. + + Parameters + ---------- + schema : pyarrow.Schema + Schema for the record batches, and the condition. + condition : pyarrow.gandiva.Condition + Filter condition. + configuration : pyarrow.gandiva.Configuration, default None + Configuration for the filter. + + Returns + ------- + Filter instance + """ + cdef shared_ptr[CFilter] result + if condition is None: + raise TypeError("Condition must not be None") + + if configuration is None: + configuration = Configuration() + + check_status( + Filter_Make(schema.sp_schema, condition.condition, configuration.configuration, &result)) + return Filter.create(result) + + +cdef class FunctionSignature(_Weakrefable): + """ + Signature of a Gandiva function including name, parameter types + and return type. + """ + + cdef: + shared_ptr[CFunctionSignature] signature + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly." + .format(self.__class__.__name__)) + + @staticmethod + cdef create(shared_ptr[CFunctionSignature] signature): + cdef FunctionSignature self = FunctionSignature.__new__( + FunctionSignature) + self.signature = signature + return self + + def return_type(self): + return pyarrow_wrap_data_type(self.signature.get().ret_type()) + + def param_types(self): + result = [] + cdef vector[shared_ptr[CDataType]] types = \ + self.signature.get().param_types() + for t in types: + result.append(pyarrow_wrap_data_type(t)) + return result + + def name(self): + return self.signature.get().base_name().decode() + + def __repr__(self): + signature = self.signature.get().ToString().decode() + return "FunctionSignature(" + signature + ")" + + +def get_registered_function_signatures(): + """ + Return the function in Gandiva's ExpressionRegistry. 
+ + Returns + ------- + registry: a list of registered function signatures + """ + results = [] + + cdef vector[shared_ptr[CFunctionSignature]] signatures = \ + GetRegisteredFunctionSignatures() + + for signature in signatures: + results.append(FunctionSignature.create(signature)) + + return results diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/io.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/io.pxi new file mode 100644 index 0000000000000000000000000000000000000000..b3de15067fbfae196ed5f9490301d10c469e1df3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/io.pxi @@ -0,0 +1,2919 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Cython wrappers for IO interfaces defined in arrow::io and messaging in +# arrow::ipc + +from libc.stdlib cimport malloc, free + +import codecs +import pickle +import re +import sys +import threading +import time +import warnings +from io import BufferedIOBase, IOBase, TextIOBase, UnsupportedOperation +from queue import Queue, Empty as QueueEmpty + +from pyarrow.lib cimport check_status, HaveLibHdfs +from pyarrow.util import _is_path_like, _stringify_path + + +# 64K +DEFAULT_BUFFER_SIZE = 2 ** 16 + + +cdef extern from "Python.h": + # To let us get a PyObject* and avoid Cython auto-ref-counting + PyObject* PyBytes_FromStringAndSizeNative" PyBytes_FromStringAndSize"( + char *v, Py_ssize_t len) except NULL + + # Workaround https://github.com/cython/cython/issues/4707 + bytearray PyByteArray_FromStringAndSize(char *string, Py_ssize_t len) + + +def have_libhdfs(): + """ + Return true if HDFS (HadoopFileSystem) library is set up correctly. + """ + try: + with nogil: + check_status(HaveLibHdfs()) + return True + except Exception: + return False + + +def io_thread_count(): + """ + Return the number of threads to use for I/O operations. + + Many operations, such as scanning a dataset, will implicitly make + use of this pool. The number of threads is set to a fixed value at + startup. It can be modified at runtime by calling + :func:`set_io_thread_count()`. + + See Also + -------- + set_io_thread_count : Modify the size of this pool. + cpu_count : The analogous function for the CPU thread pool. + """ + return GetIOThreadPoolCapacity() + + +def set_io_thread_count(int count): + """ + Set the number of threads to use for I/O operations. + + Many operations, such as scanning a dataset, will implicitly make + use of this pool. + + Parameters + ---------- + count : int + The max number of threads that may be used for I/O. + Must be positive. + + See Also + -------- + io_thread_count : Get the size of this pool. + set_cpu_count : The analogous function for the CPU thread pool. 
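+
+    Examples
+    --------
+    A minimal sketch (the thread count of 8 is an arbitrary example value):
+
+    >>> import pyarrow as pa
+    >>> pa.set_io_thread_count(8)  # doctest: +SKIP
+    >>> pa.io_thread_count()  # doctest: +SKIP
+    8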
+ """ + if count < 1: + raise ValueError("IO thread count must be strictly positive") + check_status(SetIOThreadPoolCapacity(count)) + + +cdef class NativeFile(_Weakrefable): + """ + The base class for all Arrow streams. + + Streams are either readable, writable, or both. + They optionally support seeking. + + While this class exposes methods to read or write data from Python, the + primary intent of using a Arrow stream is to pass it to other Arrow + facilities that will make use of it, such as Arrow IPC routines. + + Be aware that there are subtle differences with regular Python files, + e.g. destroying a writable Arrow stream without closing it explicitly + will not flush any pending data. + """ + + # Default chunk size for chunked reads. + # Use a large enough value for networked filesystems. + _default_chunk_size = 256 * 1024 + + def __cinit__(self): + self.own_file = False + self.is_readable = False + self.is_writable = False + self.is_seekable = False + self._is_appending = False + + def __dealloc__(self): + if self.own_file: + self.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, tb): + self.close() + + def __repr__(self): + name = f"pyarrow.{self.__class__.__name__}" + return (f"<{name} " + f"closed={self.closed} " + f"own_file={self.own_file} " + f"is_seekable={self.is_seekable} " + f"is_writable={self.is_writable} " + f"is_readable={self.is_readable}>") + + @property + def mode(self): + """ + The file mode. Currently instances of NativeFile may support: + + * rb: binary read + * wb: binary write + * rb+: binary read and write + * ab: binary append + """ + # Emulate built-in file modes + if self.is_readable and self.is_writable: + return 'rb+' + elif self.is_readable: + return 'rb' + elif self.is_writable and self._is_appending: + return 'ab' + elif self.is_writable: + return 'wb' + else: + raise ValueError('File object is malformed, has no mode') + + def readable(self): + self._assert_open() + return self.is_readable + + def writable(self): + self._assert_open() + return self.is_writable + + def seekable(self): + self._assert_open() + return self.is_seekable + + def isatty(self): + self._assert_open() + return False + + def fileno(self): + """ + NOT IMPLEMENTED + """ + raise UnsupportedOperation() + + @property + def closed(self): + if self.is_readable: + return self.input_stream.get().closed() + elif self.is_writable: + return self.output_stream.get().closed() + else: + return True + + def close(self): + if not self.closed: + with nogil: + if self.is_readable: + check_status(self.input_stream.get().Close()) + else: + check_status(self.output_stream.get().Close()) + + cdef set_random_access_file(self, shared_ptr[CRandomAccessFile] handle): + self.input_stream = handle + self.random_access = handle + self.is_seekable = True + + cdef set_input_stream(self, shared_ptr[CInputStream] handle): + self.input_stream = handle + self.random_access.reset() + self.is_seekable = False + + cdef set_output_stream(self, shared_ptr[COutputStream] handle): + self.output_stream = handle + + cdef shared_ptr[CRandomAccessFile] get_random_access_file(self) except *: + self._assert_readable() + self._assert_seekable() + return self.random_access + + cdef shared_ptr[CInputStream] get_input_stream(self) except *: + self._assert_readable() + return self.input_stream + + cdef shared_ptr[COutputStream] get_output_stream(self) except *: + self._assert_writable() + return self.output_stream + + def _assert_open(self): + if self.closed: + raise ValueError("I/O 
operation on closed file") + + def _assert_readable(self): + self._assert_open() + if not self.is_readable: + # XXX UnsupportedOperation + raise IOError("only valid on readable files") + + def _assert_writable(self): + self._assert_open() + if not self.is_writable: + raise IOError("only valid on writable files") + + def _assert_seekable(self): + self._assert_open() + if not self.is_seekable: + raise IOError("only valid on seekable files") + + def size(self): + """ + Return file size + """ + cdef int64_t size + + handle = self.get_random_access_file() + with nogil: + size = GetResultValue(handle.get().GetSize()) + + return size + + def metadata(self): + """ + Return file metadata + """ + cdef: + shared_ptr[const CKeyValueMetadata] c_metadata + + handle = self.get_input_stream() + with nogil: + c_metadata = GetResultValue(handle.get().ReadMetadata()) + + metadata = {} + if c_metadata.get() != nullptr: + for i in range(c_metadata.get().size()): + metadata[frombytes(c_metadata.get().key(i))] = \ + c_metadata.get().value(i) + return metadata + + def tell(self): + """ + Return current stream position + """ + cdef int64_t position + + if self.is_readable: + rd_handle = self.get_random_access_file() + with nogil: + position = GetResultValue(rd_handle.get().Tell()) + else: + wr_handle = self.get_output_stream() + with nogil: + position = GetResultValue(wr_handle.get().Tell()) + + return position + + def seek(self, int64_t position, int whence=0): + """ + Change current file stream position + + Parameters + ---------- + position : int + Byte offset, interpreted relative to value of whence argument + whence : int, default 0 + Point of reference for seek offset + + Notes + ----- + Values of whence: + * 0 -- start of stream (the default); offset should be zero or positive + * 1 -- current stream position; offset may be negative + * 2 -- end of stream; offset is usually negative + + Returns + ------- + int + The new absolute stream position. + """ + cdef int64_t offset + handle = self.get_random_access_file() + + with nogil: + if whence == 0: + offset = position + elif whence == 1: + offset = GetResultValue(handle.get().Tell()) + offset = offset + position + elif whence == 2: + offset = GetResultValue(handle.get().GetSize()) + offset = offset + position + else: + with gil: + raise ValueError("Invalid value of whence: {0}" + .format(whence)) + check_status(handle.get().Seek(offset)) + + return self.tell() + + def flush(self): + """ + Flush the stream, if applicable. + + An error is raised if stream is not writable. + """ + self._assert_open() + # For IOBase compatibility, flush() on an input stream is a no-op + if self.is_writable: + handle = self.get_output_stream() + with nogil: + check_status(handle.get().Flush()) + + def write(self, data): + """ + Write data to the file. + + Parameters + ---------- + data : bytes-like object or exporter of buffer protocol + + Returns + ------- + int + nbytes: number of bytes written + """ + self._assert_writable() + handle = self.get_output_stream() + + cdef shared_ptr[CBuffer] buf = as_c_buffer(data) + + with nogil: + check_status(handle.get().WriteBuffer(buf)) + return buf.get().size() + + def read(self, nbytes=None): + """ + Read and return up to n bytes. + + If *nbytes* is None, then the entire remaining file contents are read. 
+ + Parameters + ---------- + nbytes : int, default None + + Returns + ------- + data : bytes + """ + cdef: + int64_t c_nbytes + int64_t bytes_read = 0 + PyObject* obj + + if nbytes is None: + if not self.is_seekable: + # Cannot get file size => read chunkwise + bs = self._default_chunk_size + chunks = [] + while True: + chunk = self.read(bs) + if not chunk: + break + chunks.append(chunk) + return b"".join(chunks) + + c_nbytes = self.size() - self.tell() + else: + c_nbytes = nbytes + + handle = self.get_input_stream() + + # Allocate empty write space + obj = PyBytes_FromStringAndSizeNative(NULL, c_nbytes) + + cdef uint8_t* buf = cp.PyBytes_AS_STRING( obj) + with nogil: + bytes_read = GetResultValue(handle.get().Read(c_nbytes, buf)) + + if bytes_read < c_nbytes: + cp._PyBytes_Resize(&obj, bytes_read) + + return PyObject_to_object(obj) + + def get_stream(self, file_offset, nbytes): + """ + Return an input stream that reads a file segment independent of the + state of the file. + + Allows reading portions of a random access file as an input stream + without interfering with each other. + + Parameters + ---------- + file_offset : int + nbytes : int + + Returns + ------- + stream : NativeFile + """ + cdef: + shared_ptr[CInputStream] data + int64_t c_file_offset + int64_t c_nbytes + + c_file_offset = file_offset + c_nbytes = nbytes + + handle = self.get_random_access_file() + + data = GetResultValue( + CRandomAccessFile.GetStream(handle, c_file_offset, c_nbytes)) + + stream = NativeFile() + stream.set_input_stream(data) + stream.is_readable = True + + return stream + + def read_at(self, nbytes, offset): + """ + Read indicated number of bytes at offset from the file + + Parameters + ---------- + nbytes : int + offset : int + + Returns + ------- + data : bytes + """ + cdef: + int64_t c_nbytes + int64_t c_offset + int64_t bytes_read = 0 + PyObject* obj + + c_nbytes = nbytes + + c_offset = offset + + handle = self.get_random_access_file() + + # Allocate empty write space + obj = PyBytes_FromStringAndSizeNative(NULL, c_nbytes) + + cdef uint8_t* buf = cp.PyBytes_AS_STRING( obj) + with nogil: + bytes_read = GetResultValue(handle.get(). + ReadAt(c_offset, c_nbytes, buf)) + + if bytes_read < c_nbytes: + cp._PyBytes_Resize(&obj, bytes_read) + + return PyObject_to_object(obj) + + def read1(self, nbytes=None): + """Read and return up to n bytes. + + Unlike read(), if *nbytes* is None then a chunk is read, not the + entire file. + + Parameters + ---------- + nbytes : int, default None + The maximum number of bytes to read. + + Returns + ------- + data : bytes + """ + if nbytes is None: + # The expectation when passing `nbytes=None` is not to read the + # entire file but to issue a single underlying read call up to + # a reasonable size (the use case being to read a bufferable + # amount of bytes, such as with io.TextIOWrapper). + nbytes = self._default_chunk_size + return self.read(nbytes) + + def readall(self): + return self.read() + + def readinto(self, b): + """ + Read into the supplied buffer + + Parameters + ---------- + b : buffer-like object + A writable buffer object (such as a bytearray). 
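+
+ A minimal illustrative sketch (the sample bytes and buffer size are
+ arbitrary):
+
+ >>> import pyarrow as pa
+ >>> dest = bytearray(4)
+ >>> pa.BufferReader(b'data!').readinto(dest)
+ 4
+ >>> bytes(dest)
+ b'data'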
+ + Returns + ------- + written : int + number of bytes written + """ + + cdef: + int64_t bytes_read + uint8_t* buf + Buffer py_buf + int64_t buf_len + + handle = self.get_input_stream() + + py_buf = py_buffer(b) + buf_len = py_buf.size + buf = py_buf.buffer.get().mutable_data() + + with nogil: + bytes_read = GetResultValue(handle.get().Read(buf_len, buf)) + + return bytes_read + + def readline(self, size=None): + """NOT IMPLEMENTED. Read and return a line of bytes from the file. + + If size is specified, read at most size bytes. + + Line terminator is always b"\\n". + + Parameters + ---------- + size : int + maximum number of bytes read + """ + raise UnsupportedOperation() + + def readlines(self, hint=None): + """NOT IMPLEMENTED. Read lines of the file + + Parameters + ---------- + hint : int + maximum number of bytes read until we stop + """ + raise UnsupportedOperation() + + def __iter__(self): + self._assert_readable() + return self + + def __next__(self): + line = self.readline() + if not line: + raise StopIteration + return line + + def read_buffer(self, nbytes=None): + """ + Read from buffer. + + Parameters + ---------- + nbytes : int, optional + maximum number of bytes read + """ + cdef: + int64_t c_nbytes + int64_t bytes_read = 0 + shared_ptr[CBuffer] output + + handle = self.get_input_stream() + + if nbytes is None: + if not self.is_seekable: + # Cannot get file size => read chunkwise + return py_buffer(self.read()) + c_nbytes = self.size() - self.tell() + else: + c_nbytes = nbytes + + with nogil: + output = GetResultValue(handle.get().ReadBuffer(c_nbytes)) + + return pyarrow_wrap_buffer(output) + + def truncate(self): + """ + NOT IMPLEMENTED + """ + raise UnsupportedOperation() + + def writelines(self, lines): + """ + Write lines to the file. + + Parameters + ---------- + lines : iterable + Iterable of bytes-like objects or exporters of buffer protocol + """ + self._assert_writable() + + for line in lines: + self.write(line) + + def download(self, stream_or_path, buffer_size=None): + """ + Read this file completely to a local path or destination stream. + + This method first seeks to the beginning of the file. + + Parameters + ---------- + stream_or_path : str or file-like object + If a string, a local file path to write to; otherwise, + should be a writable stream. + buffer_size : int, optional + The buffer size to use for data transfers. 
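+
+ Illustrative sketch, downloading into an ``io.BytesIO`` sink (the
+ sample payload is arbitrary):
+
+ >>> import io
+ >>> import pyarrow as pa
+ >>> sink = io.BytesIO()
+ >>> pa.BufferReader(b'payload').download(sink)
+ >>> sink.getvalue()
+ b'payload'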
+ """ + cdef: + int64_t bytes_read = 0 + uint8_t* buf + + if not is_threading_enabled(): + return self._download_nothreads(stream_or_path, buffer_size) + + handle = self.get_input_stream() + + buffer_size = buffer_size or DEFAULT_BUFFER_SIZE + + write_queue = Queue(50) + + if not hasattr(stream_or_path, 'read'): + stream = open(stream_or_path, 'wb') + + def cleanup(): + stream.close() + else: + stream = stream_or_path + + def cleanup(): + pass + + done = False + exc_info = None + + def bg_write(): + try: + while not done or write_queue.qsize() > 0: + try: + buf = write_queue.get(timeout=0.01) + except QueueEmpty: + continue + stream.write(buf) + except Exception as e: + exc_info = sys.exc_info() + finally: + cleanup() + + self.seek(0) + + writer_thread = threading.Thread(target=bg_write) + + # This isn't ideal -- PyBytes_FromStringAndSize copies the data from + # the passed buffer, so it's hard for us to avoid doubling the memory + buf = malloc(buffer_size) + if buf == NULL: + raise MemoryError("Failed to allocate {0} bytes" + .format(buffer_size)) + + writer_thread.start() + + cdef int64_t total_bytes = 0 + cdef int32_t c_buffer_size = buffer_size + + try: + while True: + with nogil: + bytes_read = GetResultValue( + handle.get().Read(c_buffer_size, buf)) + + total_bytes += bytes_read + + # EOF + if bytes_read == 0: + break + + pybuf = cp.PyBytes_FromStringAndSize(buf, + bytes_read) + + if writer_thread.is_alive(): + while write_queue.full(): + time.sleep(0.01) + else: + break + + write_queue.put_nowait(pybuf) + finally: + free(buf) + done = True + + writer_thread.join() + if exc_info is not None: + raise exc_info[0], exc_info[1], exc_info[2] + + def _download_nothreads(self, stream_or_path, buffer_size=None): + """ + Internal method to do a download without separate threads, queues etc. + Called by download above if is_threading_enabled() == False + """ + cdef: + int64_t bytes_read = 0 + uint8_t* buf + + handle = self.get_input_stream() + + buffer_size = buffer_size or DEFAULT_BUFFER_SIZE + + if not hasattr(stream_or_path, 'read'): + stream = open(stream_or_path, 'wb') + + def cleanup(): + stream.close() + else: + stream = stream_or_path + + def cleanup(): + pass + + self.seek(0) + + # This isn't ideal -- PyBytes_FromStringAndSize copies the data from + # the passed buffer, so it's hard for us to avoid doubling the memory + buf = malloc(buffer_size) + if buf == NULL: + raise MemoryError("Failed to allocate {0} bytes" + .format(buffer_size)) + + cdef int64_t total_bytes = 0 + cdef int32_t c_buffer_size = buffer_size + + try: + while True: + with nogil: + bytes_read = GetResultValue( + handle.get().Read(c_buffer_size, buf)) + + total_bytes += bytes_read + + # EOF + if bytes_read == 0: + break + + pybuf = cp.PyBytes_FromStringAndSize(buf, + bytes_read) + + # no background thread - write on main thread + stream.write(pybuf) + finally: + free(buf) + cleanup() + + def upload(self, stream, buffer_size=None): + """ + Write from a source stream to this file. + + Parameters + ---------- + stream : file-like object + Source stream to pipe to this file. + buffer_size : int, optional + The buffer size to use for data transfers. 
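+
+ Illustrative sketch, uploading from an ``io.BytesIO`` source into a
+ ``BufferOutputStream`` (the sample payload is arbitrary):
+
+ >>> import io
+ >>> import pyarrow as pa
+ >>> sink = pa.BufferOutputStream()
+ >>> sink.upload(io.BytesIO(b'copied'))
+ >>> sink.getvalue().to_pybytes()
+ b'copied'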
+ """ + if not is_threading_enabled(): + return self._upload_nothreads(stream, buffer_size) + + write_queue = Queue(50) + self._assert_writable() + + buffer_size = buffer_size or DEFAULT_BUFFER_SIZE + + done = False + exc_info = None + + def bg_write(): + try: + while not done or write_queue.qsize() > 0: + try: + buf = write_queue.get(timeout=0.01) + except QueueEmpty: + continue + + self.write(buf) + + except Exception as e: + exc_info = sys.exc_info() + + writer_thread = threading.Thread(target=bg_write) + writer_thread.start() + + try: + while True: + buf = stream.read(buffer_size) + if not buf: + break + + if writer_thread.is_alive(): + while write_queue.full(): + time.sleep(0.01) + else: + break + + write_queue.put_nowait(buf) + finally: + done = True + + writer_thread.join() + if exc_info is not None: + raise exc_info[0], exc_info[1], exc_info[2] + + def _upload_nothreads(self, stream, buffer_size=None): + """ + Internal method to do an upload without separate threads, queues etc. + Called by upload above if is_threading_enabled() == False + """ + self._assert_writable() + + buffer_size = buffer_size or DEFAULT_BUFFER_SIZE + + while True: + buf = stream.read(buffer_size) + if not buf: + break + + # no threading - just write + self.write(buf) + + +BufferedIOBase.register(NativeFile) + +# ---------------------------------------------------------------------- +# Python file-like objects + + +cdef class PythonFile(NativeFile): + """ + A stream backed by a Python file object. + + This class allows using Python file objects with arbitrary Arrow + functions, including functions written in another language than Python. + + As a downside, there is a non-zero redirection cost in translating + Arrow stream calls to Python method calls. Furthermore, Python's + Global Interpreter Lock may limit parallelism in some situations. + + Examples + -------- + >>> import io + >>> import pyarrow as pa + >>> pa.PythonFile(io.BytesIO()) + + + Create a stream for writing: + + >>> buf = io.BytesIO() + >>> f = pa.PythonFile(buf, mode = 'w') + >>> f.writable() + True + >>> f.write(b'PythonFile') + 10 + >>> buf.getvalue() + b'PythonFile' + >>> f.close() + >>> f + + + Create a stream for reading: + + >>> buf = io.BytesIO(b'PythonFile') + >>> f = pa.PythonFile(buf, mode = 'r') + >>> f.mode + 'rb' + >>> f.read() + b'PythonFile' + >>> f + + >>> f.close() + >>> f + + """ + cdef: + object handle + + def __cinit__(self, handle, mode=None): + self.handle = handle + + if mode is None: + try: + inferred_mode = handle.mode + except AttributeError: + # Not all file-like objects have a mode attribute + # (e.g. 
BytesIO) + try: + inferred_mode = 'w' if handle.writable() else 'r' + except AttributeError: + raise ValueError("could not infer open mode for file-like " + "object %r, please pass it explicitly" + % (handle,)) + else: + inferred_mode = mode + + if inferred_mode.startswith('w'): + kind = 'w' + elif inferred_mode.startswith('r'): + kind = 'r' + else: + raise ValueError('Invalid file mode: {0}'.format(mode)) + + # If mode was given, check it matches the given file + if mode is not None: + if isinstance(handle, IOBase): + # Python 3 IO object + if kind == 'r': + if not handle.readable(): + raise TypeError("readable file expected") + else: + if not handle.writable(): + raise TypeError("writable file expected") + # (other duck-typed file-like objects are possible) + + # If possible, check the file is a binary file + if isinstance(handle, TextIOBase): + raise TypeError("binary file expected, got text file") + + if kind == 'r': + self.set_random_access_file( + shared_ptr[CRandomAccessFile](new PyReadableFile(handle))) + self.is_readable = True + else: + self.set_output_stream( + shared_ptr[COutputStream](new PyOutputStream(handle))) + self.is_writable = True + + def truncate(self, pos=None): + """ + Parameters + ---------- + pos : int, optional + """ + self.handle.truncate(pos) + + def readline(self, size=None): + """ + Read and return a line of bytes from the file. + + If size is specified, read at most size bytes. + + Parameters + ---------- + size : int + Maximum number of bytes read + """ + return self.handle.readline(size) + + def readlines(self, hint=None): + """ + Read lines of the file. + + Parameters + ---------- + hint : int + Maximum number of bytes read until we stop + """ + return self.handle.readlines(hint) + + +cdef class MemoryMappedFile(NativeFile): + """ + A stream that represents a memory-mapped file. + + Supports 'r', 'r+', 'w' modes. + + Examples + -------- + Create a new file with memory map: + + >>> import pyarrow as pa + >>> mmap = pa.create_memory_map('example_mmap.dat', 10) + >>> mmap + + >>> mmap.close() + + Open an existing file with memory map: + + >>> with pa.memory_map('example_mmap.dat') as mmap: + ... mmap + ... + + """ + cdef: + shared_ptr[CMemoryMappedFile] handle + object path + + @staticmethod + def create(path, size): + """ + Create a MemoryMappedFile + + Parameters + ---------- + path : str + Where to create the file. + size : int + Size of the memory mapped file. 
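+
+ Illustrative sketch (the file name and size are arbitrary):
+
+ >>> import pyarrow as pa
+ >>> mmap = pa.MemoryMappedFile.create('example_created.dat', 4)
+ >>> mmap.write(b'data')
+ 4
+ >>> mmap.close()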
+ """ + cdef: + shared_ptr[CMemoryMappedFile] handle + c_string c_path = encode_file_path(path) + int64_t c_size = size + + with nogil: + handle = GetResultValue(CMemoryMappedFile.Create(c_path, c_size)) + + cdef MemoryMappedFile result = MemoryMappedFile() + result.path = path + result.is_readable = True + result.is_writable = True + result.set_output_stream( handle) + result.set_random_access_file( handle) + result.handle = handle + + return result + + def _open(self, path, mode='r'): + self.path = path + + cdef: + FileMode c_mode + shared_ptr[CMemoryMappedFile] handle + c_string c_path = encode_file_path(path) + + if mode in ('r', 'rb'): + c_mode = FileMode_READ + self.is_readable = True + elif mode in ('w', 'wb'): + c_mode = FileMode_WRITE + self.is_writable = True + elif mode in ('r+', 'r+b', 'rb+'): + c_mode = FileMode_READWRITE + self.is_readable = True + self.is_writable = True + else: + raise ValueError('Invalid file mode: {0}'.format(mode)) + + with nogil: + handle = GetResultValue(CMemoryMappedFile.Open(c_path, c_mode)) + + self.set_output_stream( handle) + self.set_random_access_file( handle) + self.handle = handle + + def resize(self, new_size): + """ + Resize the map and underlying file. + + Parameters + ---------- + new_size : new size in bytes + """ + check_status(self.handle.get().Resize(new_size)) + + def fileno(self): + self._assert_open() + return self.handle.get().file_descriptor() + + +def memory_map(path, mode='r'): + """ + Open memory map at file path. Size of the memory map cannot change. + + Parameters + ---------- + path : str + mode : {'r', 'r+', 'w'}, default 'r' + Whether the file is opened for reading ('r'), writing ('w') + or both ('r+'). + + Returns + ------- + mmap : MemoryMappedFile + + Examples + -------- + Reading from a memory map without any memory allocation or copying: + + >>> import pyarrow as pa + >>> with pa.output_stream('example_mmap.txt') as stream: + ... stream.write(b'Constructing a buffer referencing the mapped memory') + ... + 51 + >>> with pa.memory_map('example_mmap.txt') as mmap: + ... mmap.read_at(6,45) + ... + b'memory' + """ + _check_is_file(path) + + cdef MemoryMappedFile mmap = MemoryMappedFile() + mmap._open(path, mode) + return mmap + + +cdef _check_is_file(path): + if os.path.isdir(path): + raise IOError("Expected file path, but {0} is a directory" + .format(path)) + + +def create_memory_map(path, size): + """ + Create a file of the given size and memory-map it. + + Parameters + ---------- + path : str + The file path to create, on the local filesystem. + size : int + The file size to create. + + Returns + ------- + mmap : MemoryMappedFile + + Examples + -------- + Create a file with a memory map: + + >>> import pyarrow as pa + >>> with pa.create_memory_map('example_mmap_create.dat', 27) as mmap: + ... mmap.write(b'Create a memory-mapped file') + ... mmap.read_at(10, 9) + ... + 27 + b'memory-map' + """ + return MemoryMappedFile.create(path, size) + + +cdef class OSFile(NativeFile): + """ + A stream backed by a regular file descriptor. + + Examples + -------- + Create a new file to write to: + + >>> import pyarrow as pa + >>> with pa.OSFile('example_osfile.arrow', mode='w') as f: + ... f.writable() + ... f.write(b'OSFile') + ... f.seekable() + ... + True + 6 + False + + Open the file to read: + + >>> with pa.OSFile('example_osfile.arrow', mode='r') as f: + ... f.mode + ... f.read() + ... + 'rb' + b'OSFile' + + Open the file to append: + + >>> with pa.OSFile('example_osfile.arrow', mode='ab') as f: + ... f.mode + ... 
f.write(b' is super!') + ... + 'ab' + 10 + >>> with pa.OSFile('example_osfile.arrow') as f: + ... f.read() + ... + b'OSFile is super!' + + Inspect created OSFile: + + >>> pa.OSFile('example_osfile.arrow') + + """ + cdef: + object path + + def __cinit__(self, path, mode='r', MemoryPool memory_pool=None): + _check_is_file(path) + self.path = path + + cdef: + FileMode c_mode + shared_ptr[Readable] handle + c_string c_path = encode_file_path(path) + + if mode in ('r', 'rb'): + self._open_readable(c_path, maybe_unbox_memory_pool(memory_pool)) + elif mode in ('w', 'wb'): + self._open_writable(c_path) + elif mode in ('a', 'ab'): + self._open_writable(c_path, append=True) + else: + raise ValueError('Invalid file mode: {0}'.format(mode)) + + cdef _open_readable(self, c_string path, CMemoryPool* pool): + cdef shared_ptr[ReadableFile] handle + + with nogil: + handle = GetResultValue(ReadableFile.Open(path, pool)) + + self.is_readable = True + self.set_random_access_file( handle) + + cdef _open_writable(self, c_string path, c_bool append=False): + with nogil: + self.output_stream = GetResultValue( + FileOutputStream.OpenWithAppend(path, append) + ) + self.is_writable = True + self._is_appending = append + + def fileno(self): + self._assert_open() + return self.handle.file_descriptor() + + +cdef class FixedSizeBufferWriter(NativeFile): + """ + A stream writing to a Arrow buffer. + + Examples + -------- + Create a stream to write to ``pyarrow.Buffer``: + + >>> import pyarrow as pa + >>> buf = pa.allocate_buffer(5) + >>> with pa.output_stream(buf) as stream: + ... stream.write(b'abcde') + ... stream + ... + 5 + + + Inspect the buffer: + + >>> buf.to_pybytes() + b'abcde' + >>> buf + + """ + + def __cinit__(self, Buffer buffer): + self.output_stream.reset(new CFixedSizeBufferWriter(buffer.buffer)) + self.is_writable = True + + def set_memcopy_threads(self, int num_threads): + """ + Parameters + ---------- + num_threads : int + """ + cdef CFixedSizeBufferWriter* writer = \ + self.output_stream.get() + writer.set_memcopy_threads(num_threads) + + def set_memcopy_blocksize(self, int64_t blocksize): + """ + Parameters + ---------- + blocksize : int64 + """ + cdef CFixedSizeBufferWriter* writer = \ + self.output_stream.get() + writer.set_memcopy_blocksize(blocksize) + + def set_memcopy_threshold(self, int64_t threshold): + """ + Parameters + ---------- + threshold : int64 + """ + cdef CFixedSizeBufferWriter* writer = \ + self.output_stream.get() + writer.set_memcopy_threshold(threshold) + + +# ---------------------------------------------------------------------- +# Arrow buffers + + +cdef class Buffer(_Weakrefable): + """ + The base class for all Arrow buffers. + + A buffer represents a contiguous memory area. Many buffers will own + their memory, though not all of them do. + """ + + def __cinit__(self): + pass + + def __init__(self): + raise TypeError("Do not call Buffer's constructor directly, use " + "`pyarrow.py_buffer` function instead.") + + cdef void init(self, const shared_ptr[CBuffer]& buffer): + self.buffer = buffer + self.shape[0] = self.size + self.strides[0] = (1) + + def __len__(self): + return self.size + + def __repr__(self): + name = f"pyarrow.{self.__class__.__name__}" + return (f"<{name} " + f"address={hex(self.address)} " + f"size={self.size} " + f"is_cpu={self.is_cpu} " + f"is_mutable={self.is_mutable}>") + + def _assert_cpu(self): + if not self.is_cpu: + raise NotImplementedError("Implemented only for data on CPU device") + + @property + def size(self): + """ + The buffer size in bytes. 
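+
+ Illustrative sketch (the sample bytes are arbitrary):
+
+ >>> import pyarrow as pa
+ >>> pa.py_buffer(b'hello').size
+ 5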
+ """ + return self.buffer.get().size() + + @property + def address(self): + """ + The buffer's address, as an integer. + + The returned address may point to CPU or device memory. + Use `is_cpu()` to disambiguate. + """ + return self.buffer.get().address() + + def hex(self): + """ + Compute hexadecimal representation of the buffer. + + Returns + ------- + : bytes + """ + self._assert_cpu() + return self.buffer.get().ToHexString() + + @property + def is_mutable(self): + """ + Whether the buffer is mutable. + """ + return self.buffer.get().is_mutable() + + @property + def is_cpu(self): + """ + Whether the buffer is CPU-accessible. + """ + return self.buffer.get().is_cpu() + + @property + def device(self): + """ + The device where the buffer resides. + + Returns + ------- + Device + """ + return Device.wrap(self.buffer.get().device()) + + @property + def memory_manager(self): + """ + The memory manager associated with the buffer. + + Returns + ------- + MemoryManager + """ + return MemoryManager.wrap(self.buffer.get().memory_manager()) + + @property + def device_type(self): + """ + The device type where the buffer resides. + + Returns + ------- + DeviceAllocationType + """ + return _wrap_device_allocation_type(self.buffer.get().device_type()) + + @property + def parent(self): + cdef shared_ptr[CBuffer] parent_buf = self.buffer.get().parent() + + if parent_buf.get() == NULL: + return None + else: + return pyarrow_wrap_buffer(parent_buf) + + def __getitem__(self, key): + if isinstance(key, slice): + if (key.step or 1) != 1: + raise IndexError('only slices with step 1 supported') + return _normalize_slice(self, key) + + return self.getitem(_normalize_index(key, self.size)) + + cdef getitem(self, int64_t i): + self._assert_cpu() + return self.buffer.get().data()[i] + + def slice(self, offset=0, length=None): + """ + Slice this buffer. Memory is not copied. + + You can also use the Python slice notation ``buffer[start:stop]``. + + Parameters + ---------- + offset : int, default 0 + Offset from start of buffer to slice. + length : int, default None + Length of slice (default is until end of Buffer starting from + offset). + + Returns + ------- + sliced : Buffer + A logical view over this buffer. + """ + cdef shared_ptr[CBuffer] result + + if offset < 0: + raise IndexError('Offset must be non-negative') + + if length is None: + result = GetResultValue(SliceBufferSafe(self.buffer, offset)) + else: + result = GetResultValue(SliceBufferSafe(self.buffer, offset, + length)) + return pyarrow_wrap_buffer(result) + + def equals(self, Buffer other): + """ + Determine if two buffers contain exactly the same data. + + Parameters + ---------- + other : Buffer + + Returns + ------- + are_equal : bool + True if buffer contents and size are equal + """ + if self.device != other.device: + raise ValueError( + "Device on which the data resides differs between buffers: " + f"{self.device.type_name} and {other.device.type_name}." 
+ ) + if not self.is_cpu: + if self.address != other.address: + raise NotImplementedError( + "Implemented only for data on CPU device or data with equal " + "addresses" + ) + + cdef c_bool result = False + with nogil: + result = self.buffer.get().Equals(deref(other.buffer.get())) + return result + + def __eq__(self, other): + if isinstance(other, Buffer): + return self.equals(other) + else: + return self.equals(py_buffer(other)) + + def __reduce_ex__(self, protocol): + self._assert_cpu() + + if protocol >= 5: + bufobj = pickle.PickleBuffer(self) + elif self.buffer.get().is_mutable(): + # Need to pass a bytearray to recreate a mutable buffer when + # unpickling. + bufobj = PyByteArray_FromStringAndSize( + self.buffer.get().data(), + self.buffer.get().size()) + else: + bufobj = self.to_pybytes() + return py_buffer, (bufobj,) + + def to_pybytes(self): + """ + Return this buffer as a Python bytes object. Memory is copied. + """ + self._assert_cpu() + + return cp.PyBytes_FromStringAndSize( + self.buffer.get().data(), + self.buffer.get().size()) + + def __getbuffer__(self, cp.Py_buffer* buffer, int flags): + self._assert_cpu() + + if self.buffer.get().is_mutable(): + buffer.readonly = 0 + else: + if flags & cp.PyBUF_WRITABLE: + raise BufferError("Writable buffer requested but Arrow " + "buffer was not mutable") + buffer.readonly = 1 + buffer.buf = self.buffer.get().data() + buffer.len = self.size + if buffer.buf == NULL: + # ARROW-16048: Ensure we don't export a NULL address. + assert buffer.len == 0 + buffer.buf = cp.PyBytes_AS_STRING(b"") + buffer.format = 'b' + buffer.internal = NULL + buffer.itemsize = 1 + buffer.ndim = 1 + buffer.obj = self + buffer.shape = self.shape + buffer.strides = self.strides + buffer.suboffsets = NULL + + +cdef class ResizableBuffer(Buffer): + """ + A base class for buffers that can be resized. + """ + + cdef void init_rz(self, const shared_ptr[CResizableBuffer]& buffer): + self.init( buffer) + + def resize(self, int64_t new_size, shrink_to_fit=False): + """ + Resize buffer to indicated size. + + Parameters + ---------- + new_size : int + New size of buffer (padding may be added internally). + shrink_to_fit : bool, default False + If this is true, the buffer is shrunk when new_size is less + than the current size. + If this is false, the buffer is never shrunk. + """ + cdef c_bool c_shrink_to_fit = shrink_to_fit + with nogil: + check_status(( self.buffer.get()) + .Resize(new_size, c_shrink_to_fit)) + + +cdef shared_ptr[CResizableBuffer] _allocate_buffer(CMemoryPool* pool) except *: + with nogil: + return to_shared(GetResultValue(AllocateResizableBuffer(0, pool))) + + +def allocate_buffer(int64_t size, MemoryPool memory_pool=None, + resizable=False): + """ + Allocate a mutable buffer. + + Parameters + ---------- + size : int + Number of bytes to allocate (plus internal padding) + memory_pool : MemoryPool, optional + The pool to allocate memory from. + If not given, the default memory pool is used. + resizable : bool, default False + If true, the returned buffer is resizable. 
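+
+ Illustrative sketch allocating a resizable buffer (the sizes are
+ arbitrary):
+
+ >>> import pyarrow as pa
+ >>> buf = pa.allocate_buffer(8, resizable=True)
+ >>> buf.size
+ 8
+ >>> buf.resize(16)
+ >>> buf.size
+ 16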
+ + Returns + ------- + buffer : Buffer or ResizableBuffer + """ + cdef: + CMemoryPool* cpool = maybe_unbox_memory_pool(memory_pool) + shared_ptr[CResizableBuffer] c_rz_buffer + shared_ptr[CBuffer] c_buffer + + if resizable: + with nogil: + c_rz_buffer = to_shared(GetResultValue( + AllocateResizableBuffer(size, cpool))) + return pyarrow_wrap_resizable_buffer(c_rz_buffer) + else: + with nogil: + c_buffer = to_shared(GetResultValue(AllocateBuffer(size, cpool))) + return pyarrow_wrap_buffer(c_buffer) + + +cdef class BufferOutputStream(NativeFile): + """ + An output stream that writes to a resizable buffer. + + The buffer is produced as a result when ``getvalue()`` is called. + + Examples + -------- + Create an output stream, write data to it and finalize it with + ``getvalue()``: + + >>> import pyarrow as pa + >>> f = pa.BufferOutputStream() + >>> f.write(b'pyarrow.Buffer') + 14 + >>> f.closed + False + >>> f.getvalue() + + >>> f.closed + True + """ + + cdef: + shared_ptr[CResizableBuffer] buffer + + def __cinit__(self, MemoryPool memory_pool=None): + self.buffer = _allocate_buffer(maybe_unbox_memory_pool(memory_pool)) + self.output_stream.reset(new CBufferOutputStream( + self.buffer)) + self.is_writable = True + + def getvalue(self): + """ + Finalize output stream and return result as pyarrow.Buffer. + + Returns + ------- + value : Buffer + """ + with nogil: + check_status(self.output_stream.get().Close()) + return pyarrow_wrap_buffer( self.buffer) + + +cdef class MockOutputStream(NativeFile): + + def __cinit__(self): + self.output_stream.reset(new CMockOutputStream()) + self.is_writable = True + + def size(self): + handle = self.output_stream.get() + return handle.GetExtentBytesWritten() + + +cdef class BufferReader(NativeFile): + """ + Zero-copy reader from objects convertible to Arrow buffer. + + Parameters + ---------- + obj : Python bytes or pyarrow.Buffer + + Examples + -------- + Create an Arrow input stream and inspect it: + + >>> import pyarrow as pa + >>> data = b'reader data' + >>> buf = memoryview(data) + >>> with pa.input_stream(buf) as stream: + ... stream.size() + ... stream.read(6) + ... stream.seek(7) + ... stream.read(15) + ... + 11 + b'reader' + 7 + b'data' + """ + cdef: + Buffer buffer + + # XXX Needed to make numpydoc happy + def __init__(self, obj): + pass + + def __cinit__(self, object obj): + self.buffer = as_buffer(obj) + self.set_random_access_file(shared_ptr[CRandomAccessFile]( + new CBufferReader(self.buffer.buffer))) + self.is_readable = True + + +cdef class CompressedInputStream(NativeFile): + """ + An input stream wrapper which decompresses data on the fly. + + Parameters + ---------- + stream : string, path, pyarrow.NativeFile, or file-like object + Input stream object to wrap with the compression. + compression : str + The compression type ("bz2", "brotli", "gzip", "lz4" or "zstd"). + + Examples + -------- + Create an output stream which compresses the data: + + >>> import pyarrow as pa + >>> data = b"Compressed stream" + >>> raw = pa.BufferOutputStream() + >>> with pa.CompressedOutputStream(raw, "gzip") as compressed: + ... compressed.write(data) + ... + 17 + + Create an input stream with decompression referencing the + buffer with compressed data: + + >>> cdata = raw.getvalue() + >>> with pa.input_stream(cdata, compression="gzip") as compressed: + ... compressed.read() + ... 
+ b'Compressed stream' + + which actually translates to the use of ``BufferReader``and + ``CompressedInputStream``: + + >>> raw = pa.BufferReader(cdata) + >>> with pa.CompressedInputStream(raw, "gzip") as compressed: + ... compressed.read() + ... + b'Compressed stream' + """ + + def __init__(self, object stream, str compression not None): + cdef: + NativeFile nf + Codec codec = Codec(compression) + shared_ptr[CInputStream] c_reader + shared_ptr[CCompressedInputStream] compressed_stream + nf = get_native_file(stream, False) + c_reader = nf.get_input_stream() + compressed_stream = GetResultValue( + CCompressedInputStream.Make(codec.unwrap(), c_reader) + ) + self.set_input_stream( compressed_stream) + self.is_readable = True + + +cdef class CompressedOutputStream(NativeFile): + """ + An output stream wrapper which compresses data on the fly. + + Parameters + ---------- + stream : string, path, pyarrow.NativeFile, or file-like object + Input stream object to wrap with the compression. + compression : str + The compression type ("bz2", "brotli", "gzip", "lz4" or "zstd"). + + Examples + -------- + Create an output stream which compresses the data: + + >>> import pyarrow as pa + >>> data = b"Compressed stream" + >>> raw = pa.BufferOutputStream() + >>> with pa.CompressedOutputStream(raw, "gzip") as compressed: + ... compressed.write(data) + ... + 17 + """ + + def __init__(self, object stream, str compression not None): + cdef: + Codec codec = Codec(compression) + shared_ptr[COutputStream] c_writer + shared_ptr[CCompressedOutputStream] compressed_stream + get_writer(stream, &c_writer) + compressed_stream = GetResultValue( + CCompressedOutputStream.Make(codec.unwrap(), c_writer) + ) + self.set_output_stream( compressed_stream) + self.is_writable = True + + +ctypedef CBufferedInputStream* _CBufferedInputStreamPtr +ctypedef CBufferedOutputStream* _CBufferedOutputStreamPtr +ctypedef CRandomAccessFile* _RandomAccessFilePtr + + +cdef class BufferedInputStream(NativeFile): + """ + An input stream that performs buffered reads from + an unbuffered input stream, which can mitigate the overhead + of many small reads in some cases. + + Parameters + ---------- + stream : NativeFile + The input stream to wrap with the buffer + buffer_size : int + Size of the temporary read buffer. + memory_pool : MemoryPool + The memory pool used to allocate the buffer. + """ + + def __init__(self, NativeFile stream, int buffer_size, + MemoryPool memory_pool=None): + cdef shared_ptr[CBufferedInputStream] buffered_stream + + if buffer_size <= 0: + raise ValueError('Buffer size must be larger than zero') + buffered_stream = GetResultValue(CBufferedInputStream.Create( + buffer_size, maybe_unbox_memory_pool(memory_pool), + stream.get_input_stream())) + + self.set_input_stream( buffered_stream) + self.is_readable = True + + def detach(self): + """ + Release the raw InputStream. + Further operations on this stream are invalid. + + Returns + ------- + raw : NativeFile + The underlying raw input stream + """ + cdef: + shared_ptr[CInputStream] c_raw + _CBufferedInputStreamPtr buffered + NativeFile raw + + buffered = dynamic_cast[_CBufferedInputStreamPtr]( + self.input_stream.get()) + assert buffered != nullptr + + with nogil: + c_raw = GetResultValue(buffered.Detach()) + + raw = NativeFile() + raw.is_readable = True + # Find out whether the raw stream is a RandomAccessFile + # or a mere InputStream. This helps us support seek() etc. + # selectively. 
+ if dynamic_cast[_RandomAccessFilePtr](c_raw.get()) != nullptr: + raw.set_random_access_file( + static_pointer_cast[CRandomAccessFile, CInputStream](c_raw)) + else: + raw.set_input_stream(c_raw) + return raw + + +cdef class BufferedOutputStream(NativeFile): + """ + An output stream that performs buffered reads from + an unbuffered output stream, which can mitigate the overhead + of many small writes in some cases. + + Parameters + ---------- + stream : NativeFile + The writable output stream to wrap with the buffer + buffer_size : int + Size of the buffer that should be added. + memory_pool : MemoryPool + The memory pool used to allocate the buffer. + """ + + def __init__(self, NativeFile stream, int buffer_size, + MemoryPool memory_pool=None): + cdef shared_ptr[CBufferedOutputStream] buffered_stream + + if buffer_size <= 0: + raise ValueError('Buffer size must be larger than zero') + buffered_stream = GetResultValue(CBufferedOutputStream.Create( + buffer_size, maybe_unbox_memory_pool(memory_pool), + stream.get_output_stream())) + + self.set_output_stream( buffered_stream) + self.is_writable = True + + def detach(self): + """ + Flush any buffered writes and release the raw OutputStream. + Further operations on this stream are invalid. + + Returns + ------- + raw : NativeFile + The underlying raw output stream. + """ + cdef: + shared_ptr[COutputStream] c_raw + _CBufferedOutputStreamPtr buffered + NativeFile raw + + buffered = dynamic_cast[_CBufferedOutputStreamPtr]( + self.output_stream.get()) + assert buffered != nullptr + + with nogil: + c_raw = GetResultValue(buffered.Detach()) + + raw = NativeFile() + raw.is_writable = True + raw.set_output_stream(c_raw) + return raw + + +cdef void _cb_transform(transform_func, const shared_ptr[CBuffer]& src, + shared_ptr[CBuffer]* dest) except *: + py_dest = transform_func(pyarrow_wrap_buffer(src)) + dest[0] = pyarrow_unwrap_buffer(py_buffer(py_dest)) + + +cdef class TransformInputStream(NativeFile): + """ + Transform an input stream. + + Parameters + ---------- + stream : NativeFile + The stream to transform. + transform_func : callable + The transformation to apply. + """ + + def __init__(self, NativeFile stream, transform_func): + self.set_input_stream(TransformInputStream.make_native( + stream.get_input_stream(), transform_func)) + self.is_readable = True + + @staticmethod + cdef shared_ptr[CInputStream] make_native( + shared_ptr[CInputStream] stream, transform_func) except *: + cdef: + shared_ptr[CInputStream] transform_stream + CTransformInputStreamVTable vtable + + vtable.transform = _cb_transform + return MakeTransformInputStream(stream, move(vtable), + transform_func) + + +class Transcoder: + + def __init__(self, decoder, encoder): + self._decoder = decoder + self._encoder = encoder + + def __call__(self, buf): + final = len(buf) == 0 + return self._encoder.encode(self._decoder.decode(buf, final), final) + + +cdef shared_ptr[function[StreamWrapFunc]] make_streamwrap_func( + src_encoding, dest_encoding) except *: + """ + Create a function that will add a transcoding transformation to a stream. + Data from that stream will be decoded according to ``src_encoding`` and + then re-encoded according to ``dest_encoding``. + The created function can be used to wrap streams. + + Parameters + ---------- + src_encoding : str + The codec to use when reading data. + dest_encoding : str + The codec to use for emitted data. 
+ """ + cdef: + shared_ptr[function[StreamWrapFunc]] empty_func + CTransformInputStreamVTable vtable + + vtable.transform = _cb_transform + src_codec = codecs.lookup(src_encoding) + dest_codec = codecs.lookup(dest_encoding) + return MakeStreamTransformFunc(move(vtable), + Transcoder(src_codec.incrementaldecoder(), + dest_codec.incrementalencoder())) + + +def transcoding_input_stream(stream, src_encoding, dest_encoding): + """ + Add a transcoding transformation to the stream. + Incoming data will be decoded according to ``src_encoding`` and + then re-encoded according to ``dest_encoding``. + + Parameters + ---------- + stream : NativeFile + The stream to which the transformation should be applied. + src_encoding : str + The codec to use when reading data. + dest_encoding : str + The codec to use for emitted data. + """ + src_codec = codecs.lookup(src_encoding) + dest_codec = codecs.lookup(dest_encoding) + if src_codec.name == dest_codec.name: + # Avoid losing performance on no-op transcoding + # (encoding errors won't be detected) + return stream + return TransformInputStream(stream, + Transcoder(src_codec.incrementaldecoder(), + dest_codec.incrementalencoder())) + + +cdef shared_ptr[CInputStream] native_transcoding_input_stream( + shared_ptr[CInputStream] stream, src_encoding, + dest_encoding) except *: + src_codec = codecs.lookup(src_encoding) + dest_codec = codecs.lookup(dest_encoding) + if src_codec.name == dest_codec.name: + # Avoid losing performance on no-op transcoding + # (encoding errors won't be detected) + return stream + return TransformInputStream.make_native( + stream, Transcoder(src_codec.incrementaldecoder(), + dest_codec.incrementalencoder())) + + +def py_buffer(object obj): + """ + Construct an Arrow buffer from a Python bytes-like or buffer-like object + + Parameters + ---------- + obj : object + the object from which the buffer should be constructed. + """ + cdef shared_ptr[CBuffer] buf + buf = GetResultValue(PyBuffer.FromPyObject(obj)) + return pyarrow_wrap_buffer(buf) + + +def foreign_buffer(address, size, base=None): + """ + Construct an Arrow buffer with the given *address* and *size*. + + The buffer will be optionally backed by the Python *base* object, if given. + The *base* object will be kept alive as long as this buffer is alive, + including across language boundaries (for example if the buffer is + referenced by C++ code). + + Parameters + ---------- + address : int + The starting address of the buffer. The address can + refer to both device or host memory but it must be + accessible from device after mapping it with + `get_device_address` method. + size : int + The size of device buffer in bytes. + base : {None, object} + Object that owns the referenced memory. 
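+
+ Illustrative sketch re-wrapping memory owned by another pyarrow
+ buffer (the sample bytes are arbitrary):
+
+ >>> import pyarrow as pa
+ >>> base = pa.py_buffer(b'foreign')
+ >>> buf = pa.foreign_buffer(base.address, base.size, base=base)
+ >>> buf.to_pybytes()
+ b'foreign'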
+ """ + cdef: + uintptr_t c_addr = address + int64_t c_size = size + shared_ptr[CBuffer] buf + + check_status(PyForeignBuffer.Make( c_addr, c_size, + base, &buf)) + return pyarrow_wrap_buffer(buf) + + +def as_buffer(object o): + if isinstance(o, Buffer): + return o + return py_buffer(o) + + +cdef shared_ptr[CBuffer] as_c_buffer(object o) except *: + cdef shared_ptr[CBuffer] buf + if isinstance(o, Buffer): + buf = ( o).buffer + if buf == nullptr: + raise ValueError("got null buffer") + else: + buf = GetResultValue(PyBuffer.FromPyObject(o)) + return buf + + +cdef NativeFile get_native_file(object source, c_bool use_memory_map): + try: + source_path = _stringify_path(source) + except TypeError: + if isinstance(source, Buffer): + source = BufferReader(source) + elif not isinstance(source, NativeFile) and hasattr(source, 'read'): + # Optimistically hope this is file-like + source = PythonFile(source, mode='r') + else: + if use_memory_map: + source = memory_map(source_path, mode='r') + else: + source = OSFile(source_path, mode='r') + + return source + + +cdef get_reader(object source, c_bool use_memory_map, + shared_ptr[CRandomAccessFile]* reader): + cdef NativeFile nf + + nf = get_native_file(source, use_memory_map) + reader[0] = nf.get_random_access_file() + + +cdef get_input_stream(object source, c_bool use_memory_map, + shared_ptr[CInputStream]* out): + """ + Like get_reader(), but can automatically decompress, and returns + an InputStream. + """ + cdef: + NativeFile nf + Codec codec + shared_ptr[CInputStream] input_stream + + try: + codec = Codec.detect(source) + except TypeError: + codec = None + + nf = get_native_file(source, use_memory_map) + input_stream = nf.get_input_stream() + + # codec is None if compression can't be detected + if codec is not None: + input_stream = GetResultValue( + CCompressedInputStream.Make(codec.unwrap(), input_stream) + ) + + out[0] = input_stream + + +cdef get_writer(object source, shared_ptr[COutputStream]* writer): + cdef NativeFile nf + + try: + source_path = _stringify_path(source) + except TypeError: + if not isinstance(source, NativeFile) and hasattr(source, 'write'): + # Optimistically hope this is file-like + source = PythonFile(source, mode='w') + else: + source = OSFile(source_path, mode='w') + + if isinstance(source, NativeFile): + nf = source + writer[0] = nf.get_output_stream() + else: + raise TypeError('Unable to write to object of type: {0}' + .format(type(source))) + + +# --------------------------------------------------------------------- + + +def _detect_compression(path): + if isinstance(path, str): + if path.endswith('.bz2'): + return 'bz2' + elif path.endswith('.gz'): + return 'gzip' + elif path.endswith('.lz4'): + return 'lz4' + elif path.endswith('.zst'): + return 'zstd' + + +cdef CCompressionType _ensure_compression(str name) except *: + uppercase = name.upper() + if uppercase == 'BZ2': + return CCompressionType_BZ2 + elif uppercase == 'GZIP': + return CCompressionType_GZIP + elif uppercase == 'BROTLI': + return CCompressionType_BROTLI + elif uppercase == 'LZ4' or uppercase == 'LZ4_FRAME': + return CCompressionType_LZ4_FRAME + elif uppercase == 'LZ4_RAW': + return CCompressionType_LZ4 + elif uppercase == 'SNAPPY': + return CCompressionType_SNAPPY + elif uppercase == 'ZSTD': + return CCompressionType_ZSTD + else: + raise ValueError('Invalid value for compression: {!r}'.format(name)) + + +cdef class CacheOptions(_Weakrefable): + """ + Cache options for a pre-buffered fragment scan. 
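+
+ Illustrative sketch (the limits below are arbitrary):
+
+ >>> import pyarrow as pa
+ >>> opts = pa.CacheOptions(hole_size_limit=4096, prefetch_limit=2)
+ >>> opts.hole_size_limit
+ 4096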
+ + Parameters + ---------- + hole_size_limit : int, default 8KiB + The maximum distance in bytes between two consecutive ranges; beyond + this value, ranges are not combined. + range_size_limit : int, default 32MiB + The maximum size in bytes of a combined range; if combining two + consecutive ranges would produce a range of a size greater than this, + they are not combined + lazy : bool, default True + lazy = false: request all byte ranges when PreBuffer or WillNeed is called. + lazy = True, prefetch_limit = 0: request merged byte ranges only after the reader + needs them. + lazy = True, prefetch_limit = k: prefetch up to k merged byte ranges ahead of the + range that is currently being read. + prefetch_limit : int, default 0 + The maximum number of ranges to be prefetched. This is only used for + lazy cache to asynchronously read some ranges after reading the target + range. + """ + + def __init__(self, *, hole_size_limit=None, range_size_limit=None, lazy=None, prefetch_limit=None): + self.wrapped = CCacheOptions.LazyDefaults() + if hole_size_limit is not None: + self.hole_size_limit = hole_size_limit + if range_size_limit is not None: + self.range_size_limit = range_size_limit + if lazy is not None: + self.lazy = lazy + if prefetch_limit is not None: + self.prefetch_limit = prefetch_limit + + cdef void init(self, CCacheOptions options): + self.wrapped = options + + cdef inline CCacheOptions unwrap(self): + return self.wrapped + + @staticmethod + cdef wrap(CCacheOptions options): + self = CacheOptions() + self.init(options) + return self + + @property + def hole_size_limit(self): + return self.wrapped.hole_size_limit + + @hole_size_limit.setter + def hole_size_limit(self, hole_size_limit): + self.wrapped.hole_size_limit = hole_size_limit + + @property + def range_size_limit(self): + return self.wrapped.range_size_limit + + @range_size_limit.setter + def range_size_limit(self, range_size_limit): + self.wrapped.range_size_limit = range_size_limit + + @property + def lazy(self): + return self.wrapped.lazy + + @lazy.setter + def lazy(self, lazy): + self.wrapped.lazy = lazy + + @property + def prefetch_limit(self): + return self.wrapped.prefetch_limit + + @prefetch_limit.setter + def prefetch_limit(self, prefetch_limit): + self.wrapped.prefetch_limit = prefetch_limit + + def __eq__(self, CacheOptions other): + try: + return self.unwrap().Equals(other.unwrap()) + except TypeError: + return False + + @staticmethod + def from_network_metrics(time_to_first_byte_millis, transfer_bandwidth_mib_per_sec, + ideal_bandwidth_utilization_frac=0.9, max_ideal_request_size_mib=64): + """ + Create suitable CacheOptions based on provided network metrics. + + Typically this will be used with object storage solutions like Amazon S3, + Google Cloud Storage and Azure Blob Storage. + + Parameters + ---------- + time_to_first_byte_millis : int + Seek-time or Time-To-First-Byte (TTFB) in milliseconds, also called call + setup latency of a new read request. The value is a positive integer. + transfer_bandwidth_mib_per_sec : int + Data transfer Bandwidth (BW) in MiB/sec (per connection). The value is a positive + integer. + ideal_bandwidth_utilization_frac : int, default 0.9 + Transfer bandwidth utilization fraction (per connection) to maximize the net + data load. The value is a positive float less than 1. + max_ideal_request_size_mib : int, default 64 + The maximum single data request size (in MiB) to maximize the net data load. 
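+
+ Illustrative sketch (the metric values are arbitrary):
+
+ >>> import pyarrow as pa
+ >>> opts = pa.CacheOptions.from_network_metrics(
+ ... time_to_first_byte_millis=100, transfer_bandwidth_mib_per_sec=50)
+ >>> isinstance(opts, pa.CacheOptions)
+ True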
+ + Returns + ------- + CacheOptions + """ + return CacheOptions.wrap(CCacheOptions.MakeFromNetworkMetrics( + time_to_first_byte_millis, transfer_bandwidth_mib_per_sec, + ideal_bandwidth_utilization_frac, max_ideal_request_size_mib)) + + @staticmethod + @binding(True) # Required for Cython < 3 + def _reconstruct(kwargs): + # __reduce__ doesn't allow passing named arguments directly to the + # reconstructor, hence this wrapper. + return CacheOptions(**kwargs) + + def __reduce__(self): + kwargs = dict( + hole_size_limit=self.hole_size_limit, + range_size_limit=self.range_size_limit, + lazy=self.lazy, + prefetch_limit=self.prefetch_limit, + ) + return CacheOptions._reconstruct, (kwargs,) + + +cdef class Codec(_Weakrefable): + """ + Compression codec. + + Parameters + ---------- + compression : str + Type of compression codec to initialize, valid values are: 'gzip', + 'bz2', 'brotli', 'lz4' (or 'lz4_frame'), 'lz4_raw', 'zstd' and + 'snappy'. + compression_level : int, None + Optional parameter specifying how aggressively to compress. The + possible ranges and effect of this parameter depend on the specific + codec chosen. Higher values compress more but typically use more + resources (CPU/RAM). Some codecs support negative values. + + gzip + The compression_level maps to the memlevel parameter of + deflateInit2. Higher levels use more RAM but are faster + and should have higher compression ratios. + + bz2 + The compression level maps to the blockSize100k parameter of + the BZ2_bzCompressInit function. Higher levels use more RAM + but are faster and should have higher compression ratios. + + brotli + The compression level maps to the BROTLI_PARAM_QUALITY + parameter. Higher values are slower and should have higher + compression ratios. + + lz4/lz4_frame/lz4_raw + The compression level parameter is not supported and must + be None + + zstd + The compression level maps to the compressionLevel parameter + of ZSTD_initCStream. Negative values are supported. Higher + values are slower and should have higher compression ratios. + + snappy + The compression level parameter is not supported and must + be None + + + Raises + ------ + ValueError + If invalid compression value is passed. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.Codec.is_available('gzip') + True + >>> codec = pa.Codec('gzip') + >>> codec.name + 'gzip' + >>> codec.compression_level + 9 + """ + + def __init__(self, str compression not None, compression_level=None): + cdef CCompressionType typ = _ensure_compression(compression) + if compression_level is not None: + self.wrapped = shared_ptr[CCodec](move(GetResultValue( + CCodec.CreateWithLevel(typ, compression_level)))) + else: + self.wrapped = shared_ptr[CCodec](move(GetResultValue( + CCodec.Create(typ)))) + + cdef inline CCodec* unwrap(self) nogil: + return self.wrapped.get() + + @staticmethod + def detect(path): + """ + Detect and instantiate compression codec based on file extension. + + Parameters + ---------- + path : str, path-like + File-path to detect compression from. + + Raises + ------ + TypeError + If the passed value is not path-like. + ValueError + If the compression can't be detected from the path. + + Returns + ------- + Codec + """ + return Codec(_detect_compression(_stringify_path(path))) + + @staticmethod + def is_available(str compression not None): + """ + Returns whether the compression support has been built and enabled. 
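+
+ Illustrative sketch (availability depends on how Arrow was built):
+
+ >>> import pyarrow as pa
+ >>> pa.Codec.is_available('gzip')
+ True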
+ + Parameters + ---------- + compression : str + Type of compression codec, + refer to Codec docstring for a list of supported ones. + + Returns + ------- + bool + """ + cdef CCompressionType typ = _ensure_compression(compression) + return CCodec.IsAvailable(typ) + + @staticmethod + def supports_compression_level(str compression not None): + """ + Returns true if the compression level parameter is supported + for the given codec. + + Parameters + ---------- + compression : str + Type of compression codec, + refer to Codec docstring for a list of supported ones. + """ + cdef CCompressionType typ = _ensure_compression(compression) + return CCodec.SupportsCompressionLevel(typ) + + @staticmethod + def default_compression_level(str compression not None): + """ + Returns the compression level that Arrow will use for the codec if + None is specified. + + Parameters + ---------- + compression : str + Type of compression codec, + refer to Codec docstring for a list of supported ones. + """ + cdef CCompressionType typ = _ensure_compression(compression) + return GetResultValue(CCodec.DefaultCompressionLevel(typ)) + + @staticmethod + def minimum_compression_level(str compression not None): + """ + Returns the smallest valid value for the compression level + + Parameters + ---------- + compression : str + Type of compression codec, + refer to Codec docstring for a list of supported ones. + """ + cdef CCompressionType typ = _ensure_compression(compression) + return GetResultValue(CCodec.MinimumCompressionLevel(typ)) + + @staticmethod + def maximum_compression_level(str compression not None): + """ + Returns the largest valid value for the compression level + + Parameters + ---------- + compression : str + Type of compression codec, + refer to Codec docstring for a list of supported ones. + """ + cdef CCompressionType typ = _ensure_compression(compression) + return GetResultValue(CCodec.MaximumCompressionLevel(typ)) + + @property + def name(self): + """Returns the name of the codec""" + return frombytes(self.unwrap().name()) + + @property + def compression_level(self): + """Returns the compression level parameter of the codec""" + if self.name == 'snappy': + return None + return self.unwrap().compression_level() + + def compress(self, object buf, asbytes=False, memory_pool=None): + """ + Compress data from buffer-like object. 
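+
+ Illustrative compress/decompress round trip (the payload is
+ arbitrary):
+
+ >>> import pyarrow as pa
+ >>> codec = pa.Codec('gzip')
+ >>> compressed = codec.compress(b'data to compress')
+ >>> codec.decompress(compressed, decompressed_size=16).to_pybytes()
+ b'data to compress'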
+ + Parameters + ---------- + buf : pyarrow.Buffer, bytes, or other object supporting buffer protocol + asbytes : bool, default False + Return result as Python bytes object, otherwise Buffer + memory_pool : MemoryPool, default None + Memory pool to use for buffer allocations, if any + + Returns + ------- + compressed : pyarrow.Buffer or bytes (if asbytes=True) + """ + cdef: + shared_ptr[CBuffer] owned_buf + CBuffer* c_buf + PyObject* pyobj + ResizableBuffer out_buf + int64_t max_output_size + int64_t output_length + uint8_t* output_buffer = NULL + + owned_buf = as_c_buffer(buf) + c_buf = owned_buf.get() + + max_output_size = self.wrapped.get().MaxCompressedLen( + c_buf.size(), c_buf.data() + ) + + if asbytes: + pyobj = PyBytes_FromStringAndSizeNative(NULL, max_output_size) + output_buffer = cp.PyBytes_AS_STRING( pyobj) + else: + out_buf = allocate_buffer( + max_output_size, memory_pool=memory_pool, resizable=True + ) + output_buffer = out_buf.buffer.get().mutable_data() + + with nogil: + output_length = GetResultValue( + self.unwrap().Compress( + c_buf.size(), + c_buf.data(), + max_output_size, + output_buffer + ) + ) + + if asbytes: + cp._PyBytes_Resize(&pyobj, output_length) + return PyObject_to_object(pyobj) + else: + out_buf.resize(output_length) + return out_buf + + def decompress(self, object buf, decompressed_size=None, asbytes=False, + memory_pool=None): + """ + Decompress data from buffer-like object. + + Parameters + ---------- + buf : pyarrow.Buffer, bytes, or memoryview-compatible object + decompressed_size : int, default None + Size of the decompressed result + asbytes : boolean, default False + Return result as Python bytes object, otherwise Buffer + memory_pool : MemoryPool, default None + Memory pool to use for buffer allocations, if any. + + Returns + ------- + uncompressed : pyarrow.Buffer or bytes (if asbytes=True) + """ + cdef: + shared_ptr[CBuffer] owned_buf + CBuffer* c_buf + Buffer out_buf + int64_t output_size + uint8_t* output_buffer = NULL + + owned_buf = as_c_buffer(buf) + c_buf = owned_buf.get() + + if decompressed_size is None: + raise ValueError( + "Must pass decompressed_size" + ) + + output_size = decompressed_size + + if asbytes: + pybuf = cp.PyBytes_FromStringAndSize(NULL, output_size) + output_buffer = cp.PyBytes_AS_STRING(pybuf) + else: + out_buf = allocate_buffer(output_size, memory_pool=memory_pool) + output_buffer = out_buf.buffer.get().mutable_data() + + with nogil: + GetResultValue( + self.unwrap().Decompress( + c_buf.size(), + c_buf.data(), + output_size, + output_buffer + ) + ) + + return pybuf if asbytes else out_buf + + def __repr__(self): + name = f"pyarrow.{self.__class__.__name__}" + return (f"<{name} " + f"name={self.name} " + f"compression_level={self.compression_level}>") + + +def compress(object buf, codec='lz4', asbytes=False, memory_pool=None): + """ + Compress data from buffer-like object. + + Parameters + ---------- + buf : pyarrow.Buffer, bytes, or other object supporting buffer protocol + codec : str, default 'lz4' + Compression codec. + Supported types: {'brotli, 'gzip', 'lz4', 'lz4_raw', 'snappy', 'zstd'} + asbytes : bool, default False + Return result as Python bytes object, otherwise Buffer. + memory_pool : MemoryPool, default None + Memory pool to use for buffer allocations, if any. 
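+
+ Illustrative round trip with the module-level helpers (the payload is
+ arbitrary):
+
+ >>> import pyarrow as pa
+ >>> payload = b'module level data'
+ >>> compressed = pa.compress(payload, codec='gzip')
+ >>> pa.decompress(compressed, decompressed_size=len(payload),
+ ... codec='gzip', asbytes=True)
+ b'module level data'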
+ + Returns + ------- + compressed : pyarrow.Buffer or bytes (if asbytes=True) + """ + cdef Codec coder = Codec(codec) + return coder.compress(buf, asbytes=asbytes, memory_pool=memory_pool) + + +def decompress(object buf, decompressed_size=None, codec='lz4', + asbytes=False, memory_pool=None): + """ + Decompress data from buffer-like object. + + Parameters + ---------- + buf : pyarrow.Buffer, bytes, or memoryview-compatible object + Input object to decompress data from. + decompressed_size : int, default None + Size of the decompressed result + codec : str, default 'lz4' + Compression codec. + Supported types: {'brotli, 'gzip', 'lz4', 'lz4_raw', 'snappy', 'zstd'} + asbytes : bool, default False + Return result as Python bytes object, otherwise Buffer. + memory_pool : MemoryPool, default None + Memory pool to use for buffer allocations, if any. + + Returns + ------- + uncompressed : pyarrow.Buffer or bytes (if asbytes=True) + """ + cdef Codec decoder = Codec(codec) + return decoder.decompress(buf, asbytes=asbytes, memory_pool=memory_pool, + decompressed_size=decompressed_size) + + +def input_stream(source, compression='detect', buffer_size=None): + """ + Create an Arrow input stream. + + Parameters + ---------- + source : str, Path, buffer, or file-like object + The source to open for reading. + compression : str optional, default 'detect' + The compression algorithm to use for on-the-fly decompression. + If "detect" and source is a file path, then compression will be + chosen based on the file extension. + If None, no compression will be applied. + Otherwise, a well-known algorithm name must be supplied (e.g. "gzip"). + buffer_size : int, default None + If None or 0, no buffering will happen. Otherwise the size of the + temporary read buffer. + + Examples + -------- + Create a readable BufferReader (NativeFile) from a Buffer or a memoryview object: + + >>> import pyarrow as pa + >>> buf = memoryview(b"some data") + >>> with pa.input_stream(buf) as stream: + ... stream.read(4) + ... + b'some' + + Create a readable OSFile (NativeFile) from a string or file path: + + >>> import gzip + >>> with gzip.open('example.gz', 'wb') as f: + ... f.write(b'some data') + ... + 9 + >>> with pa.input_stream('example.gz') as stream: + ... stream.read() + ... + b'some data' + + Create a readable PythonFile (NativeFile) from a a Python file object: + + >>> with open('example.txt', mode='w') as f: + ... f.write('some text') + ... + 9 + >>> with pa.input_stream('example.txt') as stream: + ... stream.read(6) + ... + b'some t' + """ + cdef NativeFile stream + + try: + source_path = _stringify_path(source) + except TypeError: + source_path = None + + if isinstance(source, NativeFile): + stream = source + elif source_path is not None: + stream = OSFile(source_path, 'r') + elif isinstance(source, (Buffer, memoryview)): + stream = BufferReader(as_buffer(source)) + elif (hasattr(source, 'read') and + hasattr(source, 'close') and + hasattr(source, 'closed')): + stream = PythonFile(source, 'r') + else: + raise TypeError("pa.input_stream() called with instance of '{}'" + .format(source.__class__)) + + if compression == 'detect': + # detect for OSFile too + compression = _detect_compression(source_path) + + if buffer_size is not None and buffer_size != 0: + stream = BufferedInputStream(stream, buffer_size) + + if compression is not None: + stream = CompressedInputStream(stream, compression) + + return stream + + +def output_stream(source, compression='detect', buffer_size=None): + """ + Create an Arrow output stream. 
+ + Parameters + ---------- + source : str, Path, buffer, file-like object + The source to open for writing. + compression : str optional, default 'detect' + The compression algorithm to use for on-the-fly compression. + If "detect" and source is a file path, then compression will be + chosen based on the file extension. + If None, no compression will be applied. + Otherwise, a well-known algorithm name must be supplied (e.g. "gzip"). + buffer_size : int, default None + If None or 0, no buffering will happen. Otherwise the size of the + temporary write buffer. + + Examples + -------- + Create a writable NativeFile from a pyarrow Buffer: + + >>> import pyarrow as pa + >>> data = b"buffer data" + >>> empty_obj = bytearray(11) + >>> buf = pa.py_buffer(empty_obj) + >>> with pa.output_stream(buf) as stream: + ... stream.write(data) + ... + 11 + >>> with pa.input_stream(buf) as stream: + ... stream.read(6) + ... + b'buffer' + + or from a memoryview object: + + >>> buf = memoryview(empty_obj) + >>> with pa.output_stream(buf) as stream: + ... stream.write(data) + ... + 11 + >>> with pa.input_stream(buf) as stream: + ... stream.read() + ... + b'buffer data' + + Create a writable NativeFile from a string or file path: + + >>> with pa.output_stream('example_second.txt') as stream: + ... stream.write(b'Write some data') + ... + 15 + >>> with pa.input_stream('example_second.txt') as stream: + ... stream.read() + ... + b'Write some data' + """ + cdef NativeFile stream + + try: + source_path = _stringify_path(source) + except TypeError: + source_path = None + + if isinstance(source, NativeFile): + stream = source + elif source_path is not None: + stream = OSFile(source_path, 'w') + elif isinstance(source, (Buffer, memoryview)): + stream = FixedSizeBufferWriter(as_buffer(source)) + elif (hasattr(source, 'write') and + hasattr(source, 'close') and + hasattr(source, 'closed')): + stream = PythonFile(source, 'w') + else: + raise TypeError("pa.output_stream() called with instance of '{}'" + .format(source.__class__)) + + if compression == 'detect': + compression = _detect_compression(source_path) + + if buffer_size is not None and buffer_size != 0: + stream = BufferedOutputStream(stream, buffer_size) + + if compression is not None: + stream = CompressedOutputStream(stream, compression) + + return stream diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/lib.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/lib.h new file mode 100644 index 0000000000000000000000000000000000000000..6856e5cba9558b8c09d15943ff20b210641150bf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/lib.h @@ -0,0 +1,83 @@ +/* Generated by Cython 3.0.12 */ + +#ifndef __PYX_HAVE__pyarrow__lib +#define __PYX_HAVE__pyarrow__lib + +#include "Python.h" + +#ifndef __PYX_HAVE_API__pyarrow__lib + +#ifdef CYTHON_EXTERN_C + #undef __PYX_EXTERN_C + #define __PYX_EXTERN_C CYTHON_EXTERN_C +#elif defined(__PYX_EXTERN_C) + #ifdef _MSC_VER + #pragma message ("Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead.") + #else + #warning Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead. 
+ #endif +#else + #define __PYX_EXTERN_C extern "C++" +#endif + +#ifndef DL_IMPORT + #define DL_IMPORT(_T) _T +#endif + +__PYX_EXTERN_C PyObject *pyarrow_wrap_buffer(std::shared_ptr< arrow::Buffer> const &); +__PYX_EXTERN_C PyObject *pyarrow_wrap_resizable_buffer(std::shared_ptr< arrow::ResizableBuffer> const &); +__PYX_EXTERN_C PyObject *pyarrow_wrap_data_type(std::shared_ptr< arrow::DataType> const &); +__PYX_EXTERN_C PyObject *pyarrow_wrap_field(std::shared_ptr< arrow::Field> const &); +__PYX_EXTERN_C PyObject *pyarrow_wrap_schema(std::shared_ptr< arrow::Schema> const &); +__PYX_EXTERN_C PyObject *pyarrow_wrap_scalar(std::shared_ptr< arrow::Scalar> const &); +__PYX_EXTERN_C PyObject *pyarrow_wrap_array(std::shared_ptr< arrow::Array> const &); +__PYX_EXTERN_C PyObject *pyarrow_wrap_chunked_array(std::shared_ptr< arrow::ChunkedArray> const &); +__PYX_EXTERN_C PyObject *pyarrow_wrap_sparse_coo_tensor(std::shared_ptr< arrow::SparseCOOTensor> const &); +__PYX_EXTERN_C PyObject *pyarrow_wrap_sparse_csc_matrix(std::shared_ptr< arrow::SparseCSCMatrix> const &); +__PYX_EXTERN_C PyObject *pyarrow_wrap_sparse_csf_tensor(std::shared_ptr< arrow::SparseCSFTensor> const &); +__PYX_EXTERN_C PyObject *pyarrow_wrap_sparse_csr_matrix(std::shared_ptr< arrow::SparseCSRMatrix> const &); +__PYX_EXTERN_C PyObject *pyarrow_wrap_tensor(std::shared_ptr< arrow::Tensor> const &); +__PYX_EXTERN_C PyObject *pyarrow_wrap_batch(std::shared_ptr< arrow::RecordBatch> const &); +__PYX_EXTERN_C PyObject *pyarrow_wrap_table(std::shared_ptr< arrow::Table> const &); +__PYX_EXTERN_C std::shared_ptr< arrow::Buffer> pyarrow_unwrap_buffer(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::DataType> pyarrow_unwrap_data_type(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::Field> pyarrow_unwrap_field(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::Schema> pyarrow_unwrap_schema(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::Scalar> pyarrow_unwrap_scalar(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::Array> pyarrow_unwrap_array(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::ChunkedArray> pyarrow_unwrap_chunked_array(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::SparseCOOTensor> pyarrow_unwrap_sparse_coo_tensor(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::SparseCSCMatrix> pyarrow_unwrap_sparse_csc_matrix(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::SparseCSFTensor> pyarrow_unwrap_sparse_csf_tensor(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::SparseCSRMatrix> pyarrow_unwrap_sparse_csr_matrix(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::Tensor> pyarrow_unwrap_tensor(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::RecordBatch> pyarrow_unwrap_batch(PyObject *); +__PYX_EXTERN_C std::shared_ptr< arrow::Table> pyarrow_unwrap_table(PyObject *); + +#endif /* !__PYX_HAVE_API__pyarrow__lib */ + +/* WARNING: the interface of the module init function changed in CPython 3.5. */ +/* It now returns a PyModuleDef instance instead of a PyModule instance. 
*/ + +#if PY_MAJOR_VERSION < 3 +PyMODINIT_FUNC initlib(void); +#else +/* WARNING: Use PyImport_AppendInittab("lib", PyInit_lib) instead of calling PyInit_lib directly from Python 3.5 */ +PyMODINIT_FUNC PyInit_lib(void); + +#if PY_VERSION_HEX >= 0x03050000 && (defined(__GNUC__) || defined(__clang__) || defined(_MSC_VER) || (defined(__cplusplus) && __cplusplus >= 201402L)) +#if defined(__cplusplus) && __cplusplus >= 201402L +[[deprecated("Use PyImport_AppendInittab(\"lib\", PyInit_lib) instead of calling PyInit_lib directly.")]] inline +#elif defined(__GNUC__) || defined(__clang__) +__attribute__ ((__deprecated__("Use PyImport_AppendInittab(\"lib\", PyInit_lib) instead of calling PyInit_lib directly."), __unused__)) __inline__ +#elif defined(_MSC_VER) +__declspec(deprecated("Use PyImport_AppendInittab(\"lib\", PyInit_lib) instead of calling PyInit_lib directly.")) __inline +#endif +static PyObject* __PYX_WARN_IF_PyInit_lib_INIT_CALLED(PyObject* res) { + return res; +} +#define PyInit_lib() __PYX_WARN_IF_PyInit_lib_INIT_CALLED(PyInit_lib()) +#endif +#endif + +#endif /* !__PYX_HAVE__pyarrow__lib */ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/lib.pyx b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/lib.pyx new file mode 100644 index 0000000000000000000000000000000000000000..2c92ecbfa73446ac5cff3b785699672dbbe4293a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/lib.pyx @@ -0,0 +1,237 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: profile = False +# cython: nonecheck = True +# distutils: language = c++ + +import datetime +import decimal as _pydecimal +try: + import numpy as np +except ImportError: + np = None +import os +import sys + +from cython.operator cimport dereference as deref +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_python cimport * +from pyarrow.includes.common cimport PyObject_to_object +cimport pyarrow.includes.libarrow_python as libarrow_python +cimport cpython as cp + + +# Initialize NumPy C API only if numpy was able to be imported +if np is not None: + arrow_init_numpy() + +# Initialize PyArrow C++ API +# (used from some of our C++ code, see e.g. ARROW-5260) +import_pyarrow() + + +MonthDayNano = NewMonthDayNanoTupleType() + + +def cpu_count(): + """ + Return the number of threads to use in parallel operations. + + The number of threads is determined at startup by inspecting the + ``OMP_NUM_THREADS`` and ``OMP_THREAD_LIMIT`` environment variables. + If neither is present, it will default to the number of hardware threads + on the system. It can be modified at runtime by calling + :func:`set_cpu_count()`. + + See Also + -------- + set_cpu_count : Modify the size of this pool. 
+ io_thread_count : The analogous function for the I/O thread pool. + """ + return GetCpuThreadPoolCapacity() + + +def set_cpu_count(int count): + """ + Set the number of threads to use in parallel operations. + + Parameters + ---------- + count : int + The number of concurrent threads that should be used. + + See Also + -------- + cpu_count : Get the size of this pool. + set_io_thread_count : The analogous function for the I/O thread pool. + """ + if count < 1: + raise ValueError("CPU count must be strictly positive") + check_status(SetCpuThreadPoolCapacity(count)) + + +def is_threading_enabled() -> bool: + """ + Returns True if threading is enabled in libarrow. + + If it isn't enabled, then python shouldn't create any + threads either, because we're probably on a system where + threading doesn't work (e.g. Emscripten). + """ + return libarrow_python.IsThreadingEnabled() + + +Type_NA = _Type_NA +Type_BOOL = _Type_BOOL +Type_UINT8 = _Type_UINT8 +Type_INT8 = _Type_INT8 +Type_UINT16 = _Type_UINT16 +Type_INT16 = _Type_INT16 +Type_UINT32 = _Type_UINT32 +Type_INT32 = _Type_INT32 +Type_UINT64 = _Type_UINT64 +Type_INT64 = _Type_INT64 +Type_HALF_FLOAT = _Type_HALF_FLOAT +Type_FLOAT = _Type_FLOAT +Type_DOUBLE = _Type_DOUBLE +Type_DECIMAL32 = _Type_DECIMAL32 +Type_DECIMAL64 = _Type_DECIMAL64 +Type_DECIMAL128 = _Type_DECIMAL128 +Type_DECIMAL256 = _Type_DECIMAL256 +Type_DATE32 = _Type_DATE32 +Type_DATE64 = _Type_DATE64 +Type_TIMESTAMP = _Type_TIMESTAMP +Type_TIME32 = _Type_TIME32 +Type_TIME64 = _Type_TIME64 +Type_DURATION = _Type_DURATION +Type_INTERVAL_MONTH_DAY_NANO = _Type_INTERVAL_MONTH_DAY_NANO +Type_BINARY = _Type_BINARY +Type_STRING = _Type_STRING +Type_LARGE_BINARY = _Type_LARGE_BINARY +Type_LARGE_STRING = _Type_LARGE_STRING +Type_FIXED_SIZE_BINARY = _Type_FIXED_SIZE_BINARY +Type_BINARY_VIEW = _Type_BINARY_VIEW +Type_STRING_VIEW = _Type_STRING_VIEW +Type_LIST = _Type_LIST +Type_LARGE_LIST = _Type_LARGE_LIST +Type_LIST_VIEW = _Type_LIST_VIEW +Type_LARGE_LIST_VIEW = _Type_LARGE_LIST_VIEW +Type_MAP = _Type_MAP +Type_FIXED_SIZE_LIST = _Type_FIXED_SIZE_LIST +Type_STRUCT = _Type_STRUCT +Type_SPARSE_UNION = _Type_SPARSE_UNION +Type_DENSE_UNION = _Type_DENSE_UNION +Type_DICTIONARY = _Type_DICTIONARY +Type_RUN_END_ENCODED = _Type_RUN_END_ENCODED + +UnionMode_SPARSE = _UnionMode_SPARSE +UnionMode_DENSE = _UnionMode_DENSE + +__pc = None +__pac = None +__cuda_loaded = None + + +def _pc(): + global __pc + if __pc is None: + import pyarrow.compute as pc + __pc = pc + return __pc + + +def _pac(): + global __pac + if __pac is None: + import pyarrow.acero as pac + __pac = pac + return __pac + + +def _ensure_cuda_loaded(): + # Try importing the cuda module to ensure libarrow_cuda gets loaded + # to register the CUDA device for the C Data Interface import + global __cuda_loaded + if __cuda_loaded is None: + try: + import pyarrow.cuda # no-cython-lint + __cuda_loaded = True + except ImportError as exc: + __cuda_loaded = str(exc) + + if __cuda_loaded is not True: + raise ImportError( + "Trying to import data on a CUDA device, but PyArrow is not built with " + f"CUDA support.\n(importing 'pyarrow.cuda' resulted in \"{__cuda_loaded}\")." 
+ ) + + +def _gdb_test_session(): + GdbTestSession() + + +# Assorted compatibility helpers +include "compat.pxi" + +# Exception types and Status handling +include "error.pxi" + +# Configuration information +include "config.pxi" + +# pandas API shim +include "pandas-shim.pxi" + +# Memory pools and allocation +include "memory.pxi" + +# Device type and memory manager +include "device.pxi" + +# DataType, Field, Schema +include "types.pxi" + +# Array scalar values +include "scalar.pxi" + +# Array types +include "array.pxi" + +# Builders +include "builder.pxi" + +# Column, Table, Record Batch +include "table.pxi" + +# Tensors +include "tensor.pxi" + +# DLPack +include "_dlpack.pxi" + +# File IO +include "io.pxi" + +# IPC / Messaging +include "ipc.pxi" + +# Micro-benchmark routines +include "benchmark.pxi" + +# Public API +include "public-api.pxi" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/libarrow_python_parquet_encryption.so.1900 b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/libarrow_python_parquet_encryption.so.1900 new file mode 100644 index 0000000000000000000000000000000000000000..cd9621415c785a581bf777beafe5ed2745331921 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/libarrow_python_parquet_encryption.so.1900 differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/memory.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/memory.pxi new file mode 100644 index 0000000000000000000000000000000000000000..1ddcb01ccb6ab2ca84786e6e60a5f4c4ffbfc5bd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/memory.pxi @@ -0,0 +1,274 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True + + +cdef class MemoryPool(_Weakrefable): + """ + Base class for memory allocation. + + Besides tracking its number of allocated bytes, a memory pool also + takes care of the required 64-byte alignment for Arrow data. + """ + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, " + "use pyarrow.*_memory_pool instead." + .format(self.__class__.__name__)) + + cdef void init(self, CMemoryPool* pool): + self.pool = pool + + def release_unused(self): + """ + Attempt to return to the OS any memory being held onto by the pool. + + This function should not be called except potentially for + benchmarking or debugging as it could be expensive and detrimental to + performance. + + This is best effort and may not have any effect on some memory pools + or in some situations (e.g. fragmentation). 
+ """ + cdef CMemoryPool* pool = c_get_memory_pool() + with nogil: + pool.ReleaseUnused() + + def bytes_allocated(self): + """ + Return the number of bytes that are currently allocated from this + memory pool. + """ + return self.pool.bytes_allocated() + + def max_memory(self): + """ + Return the peak memory allocation in this memory pool. + This can be an approximate number in multi-threaded applications. + + None is returned if the pool implementation doesn't know how to + compute this number. + """ + ret = self.pool.max_memory() + return ret if ret >= 0 else None + + @property + def backend_name(self): + """ + The name of the backend used by this MemoryPool (e.g. "jemalloc"). + """ + return frombytes(self.pool.backend_name()) + + def __repr__(self): + name = f"pyarrow.{self.__class__.__name__}" + return (f"<{name} " + f"backend_name={self.backend_name} " + f"bytes_allocated={self.bytes_allocated()} " + f"max_memory={self.max_memory()}>") + +cdef CMemoryPool* maybe_unbox_memory_pool(MemoryPool memory_pool): + if memory_pool is None: + return c_get_memory_pool() + else: + return memory_pool.pool + + +cdef api object box_memory_pool(CMemoryPool *c_pool): + cdef MemoryPool pool = MemoryPool.__new__(MemoryPool) + pool.init(c_pool) + return pool + + +cdef class LoggingMemoryPool(MemoryPool): + cdef: + unique_ptr[CLoggingMemoryPool] logging_pool + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, " + "use pyarrow.logging_memory_pool instead." + .format(self.__class__.__name__)) + + +cdef class ProxyMemoryPool(MemoryPool): + """ + Memory pool implementation that tracks the number of bytes and + maximum memory allocated through its direct calls, while redirecting + to another memory pool. + """ + cdef: + unique_ptr[CProxyMemoryPool] proxy_pool + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, " + "use pyarrow.proxy_memory_pool instead." + .format(self.__class__.__name__)) + + +def default_memory_pool(): + """ + Return the process-global memory pool. + + Examples + -------- + >>> default_memory_pool() + + """ + cdef: + MemoryPool pool = MemoryPool.__new__(MemoryPool) + pool.init(c_get_memory_pool()) + return pool + + +def proxy_memory_pool(MemoryPool parent): + """ + Create and return a MemoryPool instance that redirects to the + *parent*, but with separate allocation statistics. + + Parameters + ---------- + parent : MemoryPool + The real memory pool that should be used for allocations. + """ + cdef ProxyMemoryPool out = ProxyMemoryPool.__new__(ProxyMemoryPool) + out.proxy_pool.reset(new CProxyMemoryPool(parent.pool)) + out.init(out.proxy_pool.get()) + return out + + +def logging_memory_pool(MemoryPool parent): + """ + Create and return a MemoryPool instance that redirects to the + *parent*, but also dumps allocation logs on stderr. + + Parameters + ---------- + parent : MemoryPool + The real memory pool that should be used for allocations. + """ + cdef LoggingMemoryPool out = LoggingMemoryPool.__new__( + LoggingMemoryPool, parent) + out.logging_pool.reset(new CLoggingMemoryPool(parent.pool)) + out.init(out.logging_pool.get()) + return out + + +def system_memory_pool(): + """ + Return a memory pool based on the C malloc heap. + """ + cdef: + MemoryPool pool = MemoryPool.__new__(MemoryPool) + pool.init(c_system_memory_pool()) + return pool + + +def jemalloc_memory_pool(): + """ + Return a memory pool based on the jemalloc heap. + + NotImplementedError is raised if jemalloc support is not enabled. 
+ """ + cdef: + CMemoryPool* c_pool + MemoryPool pool = MemoryPool.__new__(MemoryPool) + check_status(c_jemalloc_memory_pool(&c_pool)) + pool.init(c_pool) + return pool + + +def mimalloc_memory_pool(): + """ + Return a memory pool based on the mimalloc heap. + + NotImplementedError is raised if mimalloc support is not enabled. + """ + cdef: + CMemoryPool* c_pool + MemoryPool pool = MemoryPool.__new__(MemoryPool) + check_status(c_mimalloc_memory_pool(&c_pool)) + pool.init(c_pool) + return pool + + +def set_memory_pool(MemoryPool pool): + """ + Set the default memory pool. + + Parameters + ---------- + pool : MemoryPool + The memory pool that should be used by default. + """ + c_set_default_memory_pool(pool.pool) + + +cdef MemoryPool _default_memory_pool = default_memory_pool() +cdef LoggingMemoryPool _logging_memory_pool = logging_memory_pool( + _default_memory_pool) + + +def log_memory_allocations(enable=True): + """ + Enable or disable memory allocator logging for debugging purposes + + Parameters + ---------- + enable : bool, default True + Pass False to disable logging + """ + if enable: + set_memory_pool(_logging_memory_pool) + else: + set_memory_pool(_default_memory_pool) + + +def total_allocated_bytes(): + """ + Return the currently allocated bytes from the default memory pool. + Other memory pools may not be accounted for. + """ + cdef CMemoryPool* pool = c_get_memory_pool() + return pool.bytes_allocated() + + +def jemalloc_set_decay_ms(decay_ms): + """ + Set arenas.dirty_decay_ms and arenas.muzzy_decay_ms to indicated number of + milliseconds. A value of 0 (the default) results in dirty / muzzy memory + pages being released right away to the OS, while a higher value will result + in a time-based decay. See the jemalloc docs for more information + + It's best to set this at the start of your application. + + Parameters + ---------- + decay_ms : int + Number of milliseconds to set for jemalloc decay conf parameters. Note + that this change will only affect future memory arenas + """ + check_status(c_jemalloc_set_decay_ms(decay_ms)) + + +def supported_memory_backends(): + """ + Return a list of available memory pool backends + """ + cdef vector[c_string] backends = c_supported_memory_backends() + return [backend.decode() for backend in backends] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/pandas-shim.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/pandas-shim.pxi new file mode 100644 index 0000000000000000000000000000000000000000..18de584bff835994a2db90af89a0c79a9fe37d97 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/pandas-shim.pxi @@ -0,0 +1,281 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# pandas lazy-loading API shim that reduces API call and import overhead + +import warnings +from threading import Lock + + +cdef class _PandasAPIShim(object): + """ + Lazy pandas importer that isolates usages of pandas APIs and avoids + importing pandas until it's actually needed + """ + cdef: + bint _tried_importing_pandas + bint _have_pandas + + cdef readonly: + object _loose_version, _version + object _pd, _types_api, _compat_module + object _data_frame, _index, _series, _categorical_type + object _datetimetz_type, _extension_array, _extension_dtype + object _array_like_types, _is_extension_array_dtype, _lock + bint has_sparse + bint _pd024 + bint _is_v1, _is_ge_v21, _is_ge_v23, _is_ge_v3, _is_ge_v3_strict + + def __init__(self): + self._lock = Lock() + self._tried_importing_pandas = False + self._have_pandas = 0 + + cdef _import_pandas(self, bint raise_): + try: + import pandas as pd + import pyarrow.pandas_compat as pdcompat + except ImportError: + self._have_pandas = False + if raise_: + raise + else: + return + + from pyarrow.vendored.version import Version + + self._pd = pd + self._version = pd.__version__ + self._loose_version = Version(pd.__version__) + self._is_v1 = False + + if self._loose_version < Version('1.0.0'): + self._have_pandas = False + if raise_: + raise ImportError( + "pyarrow requires pandas 1.0.0 or above, pandas {} is " + "installed".format(self._version) + ) + else: + warnings.warn( + "pyarrow requires pandas 1.0.0 or above, pandas {} is " + "installed. Therefore, pandas-specific integration is not " + "used.".format(self._version), stacklevel=2) + return + + self._is_v1 = self._loose_version < Version('2.0.0') + self._is_ge_v21 = self._loose_version >= Version('2.1.0') + self._is_ge_v23 = self._loose_version >= Version('2.3.0.dev0') + self._is_ge_v3 = self._loose_version >= Version('3.0.0.dev0') + self._is_ge_v3_strict = self._loose_version >= Version('3.0.0') + + self._compat_module = pdcompat + self._data_frame = pd.DataFrame + self._index = pd.Index + self._categorical_type = pd.Categorical + self._series = pd.Series + self._extension_array = pd.api.extensions.ExtensionArray + self._array_like_types = ( + self._series, self._index, self._categorical_type, + self._extension_array) + self._extension_dtype = pd.api.extensions.ExtensionDtype + self._is_extension_array_dtype = ( + pd.api.types.is_extension_array_dtype) + self._types_api = pd.api.types + self._datetimetz_type = pd.api.types.DatetimeTZDtype + self._have_pandas = True + self.has_sparse = False + + cdef inline _check_import(self, bint raise_=True): + if not self._tried_importing_pandas: + with self._lock: + if not self._tried_importing_pandas: + try: + self._import_pandas(raise_) + finally: + self._tried_importing_pandas = True + return + + if not self._have_pandas and raise_: + self._import_pandas(raise_) + + def series(self, *args, **kwargs): + self._check_import() + return self._series(*args, **kwargs) + + def data_frame(self, *args, **kwargs): + self._check_import() + return self._data_frame(*args, **kwargs) + + cdef inline bint _have_pandas_internal(self): + if not self._tried_importing_pandas: + self._check_import(raise_=False) + return self._have_pandas + + @property + def have_pandas(self): + return self._have_pandas_internal() + + @property + def compat(self): + self._check_import() + return self._compat_module + + @property + def pd(self): + self._check_import() + return self._pd + + cpdef infer_dtype(self, obj): + self._check_import() + try: + return self._types_api.infer_dtype(obj, 
skipna=False) + except AttributeError: + return self._pd.lib.infer_dtype(obj) + + cpdef pandas_dtype(self, dtype): + self._check_import() + try: + return self._types_api.pandas_dtype(dtype) + except AttributeError: + return None + + @property + def loose_version(self): + self._check_import() + return self._loose_version + + @property + def version(self): + self._check_import() + return self._version + + def is_v1(self): + self._check_import() + return self._is_v1 + + def is_ge_v21(self): + self._check_import() + return self._is_ge_v21 + + def is_ge_v23(self): + self._check_import() + return self._is_ge_v23 + + def is_ge_v3(self): + self._check_import() + return self._is_ge_v3 + + def is_ge_v3_strict(self): + self._check_import() + return self._is_ge_v3_strict + + def uses_string_dtype(self): + if self.is_ge_v3_strict(): + return True + try: + if self.is_ge_v23() and self.pd.options.future.infer_string: + return True + except: + pass + return False + + @property + def categorical_type(self): + self._check_import() + return self._categorical_type + + @property + def datetimetz_type(self): + self._check_import() + return self._datetimetz_type + + @property + def extension_dtype(self): + self._check_import() + return self._extension_dtype + + cpdef is_array_like(self, obj): + self._check_import() + return isinstance(obj, self._array_like_types) + + cpdef is_categorical(self, obj): + if self._have_pandas_internal(): + return isinstance(obj, self._categorical_type) + else: + return False + + cpdef is_datetimetz(self, obj): + if self._have_pandas_internal(): + return isinstance(obj, self._datetimetz_type) + else: + return False + + cpdef is_extension_array_dtype(self, obj): + self._check_import() + if self._is_extension_array_dtype: + return self._is_extension_array_dtype(obj) + else: + return False + + cpdef is_sparse(self, obj): + if self._have_pandas_internal(): + return isinstance(obj.dtype, self.pd.SparseDtype) + else: + return False + + cpdef is_data_frame(self, obj): + if self._have_pandas_internal(): + return isinstance(obj, self._data_frame) + else: + return False + + cpdef is_series(self, obj): + if self._have_pandas_internal(): + return isinstance(obj, self._series) + else: + return False + + cpdef is_index(self, obj): + if self._have_pandas_internal(): + return isinstance(obj, self._index) + else: + return False + + cpdef get_values(self, obj): + """ + Get the underlying array values of a pandas Series or Index in the + format (np.ndarray or pandas ExtensionArray) as we need them. + + Assumes obj is a pandas Series or Index. + """ + self._check_import() + if isinstance(obj.dtype, (self.pd.api.types.IntervalDtype, + self.pd.api.types.PeriodDtype)): + return obj.array + return obj.values + + def get_rangeindex_attribute(self, level, name): + # public start/stop/step attributes added in pandas 0.25.0 + self._check_import() + if hasattr(level, name): + return getattr(level, name) + return getattr(level, '_' + name) + + +cdef _PandasAPIShim pandas_api = _PandasAPIShim() +_pandas_api = pandas_api diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/scalar.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/scalar.pxi new file mode 100644 index 0000000000000000000000000000000000000000..2235cd0b981a673f206cedf97722fcb97843dc50 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/scalar.pxi @@ -0,0 +1,1305 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import collections +from cython cimport binding +from uuid import UUID + + +cdef class Scalar(_Weakrefable): + """ + The base class for scalars. + """ + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use " + "pa.scalar() instead.".format(self.__class__.__name__)) + + cdef void init(self, const shared_ptr[CScalar]& wrapped): + self.wrapped = wrapped + + @staticmethod + cdef wrap(const shared_ptr[CScalar]& wrapped): + cdef: + Scalar self + Type type_id = wrapped.get().type.get().id() + shared_ptr[CDataType] sp_data_type = wrapped.get().type + + if type_id == _Type_NA: + return _NULL + + if type_id not in _scalar_classes: + raise NotImplementedError( + "Wrapping scalar of type " + frombytes(sp_data_type.get().ToString())) + + typ = get_scalar_class_from_type(sp_data_type) + self = typ.__new__(typ) + self.init(wrapped) + + return self + + cdef inline shared_ptr[CScalar] unwrap(self) nogil: + return self.wrapped + + @property + def type(self): + """ + Data type of the Scalar object. + """ + return pyarrow_wrap_data_type(self.wrapped.get().type) + + @property + def is_valid(self): + """ + Holds a valid (non-null) value. + """ + return self.wrapped.get().is_valid + + def cast(self, object target_type=None, safe=None, options=None, memory_pool=None): + """ + Cast scalar value to another data type. + + See :func:`pyarrow.compute.cast` for usage. + + Parameters + ---------- + target_type : DataType, default None + Type to cast scalar to. + safe : boolean, default True + Whether to check for conversion errors such as overflow. + options : CastOptions, default None + Additional checks pass by CastOptions + memory_pool : MemoryPool, optional + memory pool to use for allocations during function execution. + + Returns + ------- + scalar : A Scalar of the given target data type. + """ + return _pc().cast(self, target_type, safe=safe, + options=options, memory_pool=memory_pool) + + def validate(self, *, full=False): + """ + Perform validation checks. An exception is raised if validation fails. + + By default only cheap validation checks are run. Pass `full=True` + for thorough validation checks (potentially O(n)). + + Parameters + ---------- + full : bool, default False + If True, run expensive checks, otherwise cheap checks only. 
+ + Raises + ------ + ArrowInvalid + """ + if full: + with nogil: + check_status(self.wrapped.get().ValidateFull()) + else: + with nogil: + check_status(self.wrapped.get().Validate()) + + def __repr__(self): + return ''.format( + self.__class__.__name__, self.as_py() + ) + + def __str__(self): + return str(self.as_py()) + + def equals(self, Scalar other not None): + """ + Parameters + ---------- + other : pyarrow.Scalar + + Returns + ------- + bool + """ + return self.wrapped.get().Equals(other.unwrap().get()[0]) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def __hash__(self): + cdef CScalarHash hasher + return hasher(self.wrapped) + + def __reduce__(self): + return scalar, (self.as_py(), self.type) + + def as_py(self): + raise NotImplementedError() + + +_NULL = NA = None + + +cdef class NullScalar(Scalar): + """ + Concrete class for null scalars. + """ + + def __cinit__(self): + global NA + if NA is not None: + raise RuntimeError('Cannot create multiple NullScalar instances') + self.init(shared_ptr[CScalar](new CNullScalar())) + + def __init__(self): + pass + + def as_py(self): + """ + Return this value as a Python None. + """ + return None + + +_NULL = NA = NullScalar() + + +cdef class BooleanScalar(Scalar): + """ + Concrete class for boolean scalars. + """ + + def as_py(self): + """ + Return this value as a Python bool. + """ + cdef CBooleanScalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class UInt8Scalar(Scalar): + """ + Concrete class for uint8 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CUInt8Scalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class Int8Scalar(Scalar): + """ + Concrete class for int8 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CInt8Scalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class UInt16Scalar(Scalar): + """ + Concrete class for uint16 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CUInt16Scalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class Int16Scalar(Scalar): + """ + Concrete class for int16 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CInt16Scalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class UInt32Scalar(Scalar): + """ + Concrete class for uint32 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CUInt32Scalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class Int32Scalar(Scalar): + """ + Concrete class for int32 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CInt32Scalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class UInt64Scalar(Scalar): + """ + Concrete class for uint64 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CUInt64Scalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class Int64Scalar(Scalar): + """ + Concrete class for int64 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CInt64Scalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class HalfFloatScalar(Scalar): + """ + Concrete class for float scalars. 
+ """ + + def as_py(self): + """ + Return this value as a Python float. + """ + cdef CHalfFloatScalar* sp = self.wrapped.get() + return PyHalf_FromHalf(sp.value) if sp.is_valid else None + + +cdef class FloatScalar(Scalar): + """ + Concrete class for float scalars. + """ + + def as_py(self): + """ + Return this value as a Python float. + """ + cdef CFloatScalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class DoubleScalar(Scalar): + """ + Concrete class for double scalars. + """ + + def as_py(self): + """ + Return this value as a Python float. + """ + cdef CDoubleScalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class Decimal32Scalar(Scalar): + """ + Concrete class for decimal32 scalars. + """ + + def as_py(self): + """ + Return this value as a Python Decimal. + """ + cdef: + CDecimal32Scalar* sp = self.wrapped.get() + CDecimal32Type* dtype = sp.type.get() + if sp.is_valid: + return _pydecimal.Decimal( + frombytes(sp.value.ToString(dtype.scale())) + ) + else: + return None + + +cdef class Decimal64Scalar(Scalar): + """ + Concrete class for decimal64 scalars. + """ + + def as_py(self): + """ + Return this value as a Python Decimal. + """ + cdef: + CDecimal64Scalar* sp = self.wrapped.get() + CDecimal64Type* dtype = sp.type.get() + if sp.is_valid: + return _pydecimal.Decimal( + frombytes(sp.value.ToString(dtype.scale())) + ) + else: + return None + + +cdef class Decimal128Scalar(Scalar): + """ + Concrete class for decimal128 scalars. + """ + + def as_py(self): + """ + Return this value as a Python Decimal. + """ + cdef: + CDecimal128Scalar* sp = self.wrapped.get() + CDecimal128Type* dtype = sp.type.get() + if sp.is_valid: + return _pydecimal.Decimal( + frombytes(sp.value.ToString(dtype.scale())) + ) + else: + return None + + +cdef class Decimal256Scalar(Scalar): + """ + Concrete class for decimal256 scalars. + """ + + def as_py(self): + """ + Return this value as a Python Decimal. + """ + cdef: + CDecimal256Scalar* sp = self.wrapped.get() + CDecimal256Type* dtype = sp.type.get() + if sp.is_valid: + return _pydecimal.Decimal( + frombytes(sp.value.ToString(dtype.scale())) + ) + else: + return None + + +cdef class Date32Scalar(Scalar): + """ + Concrete class for date32 scalars. + """ + + @property + def value(self): + cdef CDate32Scalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + def as_py(self): + """ + Return this value as a Python datetime.datetime instance. + """ + cdef CDate32Scalar* sp = self.wrapped.get() + + if sp.is_valid: + # shift to seconds since epoch + return ( + datetime.date(1970, 1, 1) + datetime.timedelta(days=sp.value) + ) + else: + return None + + +cdef class Date64Scalar(Scalar): + """ + Concrete class for date64 scalars. + """ + + @property + def value(self): + cdef CDate64Scalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + def as_py(self): + """ + Return this value as a Python datetime.datetime instance. 
+ """ + cdef CDate64Scalar* sp = self.wrapped.get() + + if sp.is_valid: + return ( + datetime.date(1970, 1, 1) + + datetime.timedelta(days=sp.value / 86400000) + ) + else: + return None + + +def _datetime_from_int(int64_t value, TimeUnit unit, tzinfo=None): + if unit == TimeUnit_SECOND: + delta = datetime.timedelta(seconds=value) + elif unit == TimeUnit_MILLI: + delta = datetime.timedelta(milliseconds=value) + elif unit == TimeUnit_MICRO: + delta = datetime.timedelta(microseconds=value) + else: + # TimeUnit_NANO: prefer pandas timestamps if available + if _pandas_api.have_pandas: + return _pandas_api.pd.Timestamp(value, tz=tzinfo, unit='ns') + # otherwise safely truncate to microsecond resolution datetime + if value % 1000 != 0: + raise ValueError( + "Nanosecond resolution temporal type {} is not safely " + "convertible to microseconds to convert to datetime.datetime. " + "Install pandas to return as Timestamp with nanosecond " + "support or access the .value attribute.".format(value) + ) + delta = datetime.timedelta(microseconds=value // 1000) + + dt = datetime.datetime(1970, 1, 1) + delta + # adjust timezone if set to the datatype + if tzinfo is not None: + dt = dt.replace(tzinfo=datetime.timezone.utc).astimezone(tzinfo) + + return dt + + +cdef class Time32Scalar(Scalar): + """ + Concrete class for time32 scalars. + """ + + @property + def value(self): + cdef CTime32Scalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + def as_py(self): + """ + Return this value as a Python datetime.timedelta instance. + """ + cdef: + CTime32Scalar* sp = self.wrapped.get() + CTime32Type* dtype = sp.type.get() + + if sp.is_valid: + return _datetime_from_int(sp.value, unit=dtype.unit()).time() + else: + return None + + +cdef class Time64Scalar(Scalar): + """ + Concrete class for time64 scalars. + """ + + @property + def value(self): + cdef CTime64Scalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + def as_py(self): + """ + Return this value as a Python datetime.timedelta instance. + """ + cdef: + CTime64Scalar* sp = self.wrapped.get() + CTime64Type* dtype = sp.type.get() + + if sp.is_valid: + return _datetime_from_int(sp.value, unit=dtype.unit()).time() + else: + return None + + +cdef class TimestampScalar(Scalar): + """ + Concrete class for timestamp scalars. + """ + + @property + def value(self): + cdef CTimestampScalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + def as_py(self): + """ + Return this value as a Pandas Timestamp instance (if units are + nanoseconds and pandas is available), otherwise as a Python + datetime.datetime instance. + """ + cdef: + CTimestampScalar* sp = self.wrapped.get() + CTimestampType* dtype = sp.type.get() + + if not sp.is_valid: + return None + + if not dtype.timezone().empty(): + tzinfo = string_to_tzinfo(frombytes(dtype.timezone())) + else: + tzinfo = None + + return _datetime_from_int(sp.value, unit=dtype.unit(), tzinfo=tzinfo) + + def __repr__(self): + """ + Return the representation of TimestampScalar using `strftime` to avoid + original repr datetime values being out of range. + """ + cdef: + CTimestampScalar* sp = self.wrapped.get() + CTimestampType* dtype = sp.type.get() + + if not dtype.timezone().empty(): + type_format = str(_pc().strftime(self, format="%Y-%m-%dT%H:%M:%S%z")) + else: + type_format = str(_pc().strftime(self)) + return ''.format( + self.__class__.__name__, type_format + ) + + +cdef class DurationScalar(Scalar): + """ + Concrete class for duration scalars. 
+ """ + + @property + def value(self): + cdef CDurationScalar* sp = self.wrapped.get() + return sp.value if sp.is_valid else None + + def as_py(self): + """ + Return this value as a Pandas Timedelta instance (if units are + nanoseconds and pandas is available), otherwise as a Python + datetime.timedelta instance. + """ + cdef: + CDurationScalar* sp = self.wrapped.get() + CDurationType* dtype = sp.type.get() + TimeUnit unit = dtype.unit() + + if not sp.is_valid: + return None + + if unit == TimeUnit_SECOND: + return datetime.timedelta(seconds=sp.value) + elif unit == TimeUnit_MILLI: + return datetime.timedelta(milliseconds=sp.value) + elif unit == TimeUnit_MICRO: + return datetime.timedelta(microseconds=sp.value) + else: + # TimeUnit_NANO: prefer pandas timestamps if available + if _pandas_api.have_pandas: + return _pandas_api.pd.Timedelta(sp.value, unit='ns') + # otherwise safely truncate to microsecond resolution timedelta + if sp.value % 1000 != 0: + raise ValueError( + "Nanosecond duration {} is not safely convertible to " + "microseconds to convert to datetime.timedelta. Install " + "pandas to return as Timedelta with nanosecond support or " + "access the .value attribute.".format(sp.value) + ) + return datetime.timedelta(microseconds=sp.value // 1000) + + +cdef class MonthDayNanoIntervalScalar(Scalar): + """ + Concrete class for month, day, nanosecond interval scalars. + """ + + @property + def value(self): + """ + Same as self.as_py() + """ + return self.as_py() + + def as_py(self): + """ + Return this value as a pyarrow.MonthDayNano. + """ + cdef: + PyObject* val + CMonthDayNanoIntervalScalar* scalar + scalar = self.wrapped.get() + val = GetResultValue(MonthDayNanoIntervalScalarToPyObject( + deref(scalar))) + return PyObject_to_object(val) + + +cdef class BinaryScalar(Scalar): + """ + Concrete class for binary-like scalars. + """ + + def as_buffer(self): + """ + Return a view over this value as a Buffer object. + """ + cdef CBaseBinaryScalar* sp = self.wrapped.get() + return pyarrow_wrap_buffer(sp.value) if sp.is_valid else None + + def as_py(self): + """ + Return this value as a Python bytes. + """ + buffer = self.as_buffer() + return None if buffer is None else buffer.to_pybytes() + + +cdef class LargeBinaryScalar(BinaryScalar): + pass + + +cdef class FixedSizeBinaryScalar(BinaryScalar): + pass + + +cdef class StringScalar(BinaryScalar): + """ + Concrete class for string-like (utf8) scalars. + """ + + def as_py(self): + """ + Return this value as a Python string. + """ + buffer = self.as_buffer() + return None if buffer is None else str(buffer, 'utf8') + + +cdef class LargeStringScalar(StringScalar): + pass + + +cdef class BinaryViewScalar(BinaryScalar): + pass + + +cdef class StringViewScalar(StringScalar): + pass + + +cdef class ListScalar(Scalar): + """ + Concrete class for list-like scalars. + """ + + @property + def values(self): + cdef CBaseListScalar* sp = self.wrapped.get() + if sp.is_valid: + return pyarrow_wrap_array(sp.value) + else: + return None + + def __len__(self): + """ + Return the number of values. + """ + return len(self.values) + + def __getitem__(self, i): + """ + Return the value at the given index. + """ + return self.values[_normalize_index(i, len(self))] + + def __iter__(self): + """ + Iterate over this element's values. + """ + return iter(self.values) + + def as_py(self): + """ + Return this value as a Python list. 
+ """ + arr = self.values + return None if arr is None else arr.to_pylist() + + +cdef class FixedSizeListScalar(ListScalar): + pass + + +cdef class LargeListScalar(ListScalar): + pass + + +cdef class ListViewScalar(ListScalar): + pass + + +cdef class LargeListViewScalar(ListScalar): + pass + + +cdef class StructScalar(Scalar, collections.abc.Mapping): + """ + Concrete class for struct scalars. + """ + + def __len__(self): + cdef CStructScalar* sp = self.wrapped.get() + return sp.value.size() + + def __iter__(self): + cdef: + CStructScalar* sp = self.wrapped.get() + CStructType* dtype = sp.type.get() + vector[shared_ptr[CField]] fields = dtype.fields() + + for i in range(dtype.num_fields()): + yield frombytes(fields[i].get().name()) + + def items(self): + return ((key, self[i]) for i, key in enumerate(self)) + + def __contains__(self, key): + return key in list(self) + + def __getitem__(self, key): + """ + Return the child value for the given field. + + Parameters + ---------- + index : Union[int, str] + Index / position or name of the field. + + Returns + ------- + result : Scalar + """ + cdef: + CFieldRef ref + CStructScalar* sp = self.wrapped.get() + + if isinstance(key, (bytes, str)): + ref = CFieldRef( tobytes(key)) + elif isinstance(key, int): + ref = CFieldRef( key) + else: + raise TypeError('Expected integer or string index') + + try: + return Scalar.wrap(GetResultValue(sp.field(ref))) + except ArrowInvalid as exc: + if isinstance(key, int): + raise IndexError(key) from exc + else: + raise KeyError(key) from exc + + def as_py(self): + """ + Return this value as a Python dict. + """ + if self.is_valid: + try: + return {k: self[k].as_py() for k in self.keys()} + except KeyError: + raise ValueError( + "Converting to Python dictionary is not supported when " + "duplicate field names are present") + else: + return None + + def _as_py_tuple(self): + # a version that returns a tuple instead of dict to support repr/str + # with the presence of duplicate field names + if self.is_valid: + return [(key, self[i].as_py()) for i, key in enumerate(self)] + else: + return None + + def __repr__(self): + return ''.format( + self.__class__.__name__, self._as_py_tuple() + ) + + def __str__(self): + return str(self._as_py_tuple()) + + +cdef class MapScalar(ListScalar): + """ + Concrete class for map scalars. + """ + + def __getitem__(self, i): + """ + Return the value at the given index. + """ + arr = self.values + if arr is None: + raise IndexError(i) + dct = arr[_normalize_index(i, len(arr))] + return (dct[self.type.key_field.name], dct[self.type.item_field.name]) + + def __iter__(self): + """ + Iterate over this element's values. + """ + arr = self.values + if arr is None: + return + for k, v in zip(arr.field(self.type.key_field.name), arr.field(self.type.item_field.name)): + yield (k.as_py(), v.as_py()) + + def as_py(self): + """ + Return this value as a Python list. + """ + cdef CStructScalar* sp = self.wrapped.get() + return list(self) if sp.is_valid else None + + +cdef class DictionaryScalar(Scalar): + """ + Concrete class for dictionary-encoded scalars. 
+ """ + + @staticmethod + @binding(True) # Required for cython < 3 + def _reconstruct(type, is_valid, index, dictionary): + cdef: + CDictionaryScalarIndexAndDictionary value + shared_ptr[CDictionaryScalar] wrapped + DataType type_ + Scalar index_ + Array dictionary_ + + type_ = ensure_type(type, allow_none=False) + if not isinstance(type_, DictionaryType): + raise TypeError('Must pass a DictionaryType instance') + + if isinstance(index, Scalar): + if not index.type.equals(type.index_type): + raise TypeError("The Scalar value passed as index must have " + "identical type to the dictionary type's " + "index_type") + index_ = index + else: + index_ = scalar(index, type=type_.index_type) + + if isinstance(dictionary, Array): + if not dictionary.type.equals(type.value_type): + raise TypeError("The Array passed as dictionary must have " + "identical type to the dictionary type's " + "value_type") + dictionary_ = dictionary + else: + dictionary_ = array(dictionary, type=type_.value_type) + + value.index = pyarrow_unwrap_scalar(index_) + value.dictionary = pyarrow_unwrap_array(dictionary_) + + wrapped = make_shared[CDictionaryScalar]( + value, pyarrow_unwrap_data_type(type_), (is_valid) + ) + return Scalar.wrap( wrapped) + + def __reduce__(self): + return DictionaryScalar._reconstruct, ( + self.type, self.is_valid, self.index, self.dictionary + ) + + @property + def index(self): + """ + Return this value's underlying index as a scalar. + """ + cdef CDictionaryScalar* sp = self.wrapped.get() + return Scalar.wrap(sp.value.index) + + @property + def value(self): + """ + Return the encoded value as a scalar. + """ + cdef CDictionaryScalar* sp = self.wrapped.get() + return Scalar.wrap(GetResultValue(sp.GetEncodedValue())) + + @property + def dictionary(self): + cdef CDictionaryScalar* sp = self.wrapped.get() + return pyarrow_wrap_array(sp.value.dictionary) + + def as_py(self): + """ + Return this encoded value as a Python object. + """ + return self.value.as_py() if self.is_valid else None + + +cdef class RunEndEncodedScalar(Scalar): + """ + Concrete class for RunEndEncoded scalars. + """ + @property + def value(self): + """ + Return underlying value as a scalar. + """ + cdef CRunEndEncodedScalar* sp = self.wrapped.get() + return Scalar.wrap(sp.value) + + def as_py(self): + """ + Return underlying value as a Python object. + """ + return self.value.as_py() + + +cdef class UnionScalar(Scalar): + """ + Concrete class for Union scalars. + """ + + @property + def value(self): + """ + Return underlying value as a scalar. + """ + cdef CSparseUnionScalar* sp + cdef CDenseUnionScalar* dp + if self.type.id == _Type_SPARSE_UNION: + sp = self.wrapped.get() + return Scalar.wrap(sp.value[sp.child_id]) if sp.is_valid else None + else: + dp = self.wrapped.get() + return Scalar.wrap(dp.value) if dp.is_valid else None + + def as_py(self): + """ + Return underlying value as a Python object. + """ + value = self.value + return None if value is None else value.as_py() + + @property + def type_code(self): + """ + Return the union type code for this scalar. + """ + cdef CUnionScalar* sp = self.wrapped.get() + return sp.type_code + + +cdef class ExtensionScalar(Scalar): + """ + Concrete class for Extension scalars. + """ + + @property + def value(self): + """ + Return storage value as a scalar. + """ + cdef CExtensionScalar* sp = self.wrapped.get() + return Scalar.wrap(sp.value) if sp.is_valid else None + + def as_py(self): + """ + Return this scalar as a Python object. 
+ """ + return None if self.value is None else self.value.as_py() + + @staticmethod + def from_storage(BaseExtensionType typ, value): + """ + Construct ExtensionScalar from type and storage value. + + Parameters + ---------- + typ : DataType + The extension type for the result scalar. + value : object + The storage value for the result scalar. + + Returns + ------- + ext_scalar : ExtensionScalar + """ + cdef: + shared_ptr[CExtensionScalar] sp_scalar + shared_ptr[CScalar] sp_storage + CExtensionScalar* ext_scalar + + if value is None: + storage = None + elif isinstance(value, Scalar): + if value.type != typ.storage_type: + raise TypeError("Incompatible storage type {0} " + "for extension type {1}" + .format(value.type, typ)) + storage = value + else: + storage = scalar(value, typ.storage_type) + + cdef c_bool is_valid = storage is not None and storage.is_valid + if is_valid: + sp_storage = pyarrow_unwrap_scalar(storage) + else: + sp_storage = MakeNullScalar(( typ.storage_type).sp_type) + sp_scalar = make_shared[CExtensionScalar](sp_storage, typ.sp_type, + is_valid) + with nogil: + check_status(sp_scalar.get().Validate()) + return pyarrow_wrap_scalar( sp_scalar) + + +class JsonScalar(ExtensionScalar): + """ + Concrete class for JSON extension scalar. + """ + + +class UuidScalar(ExtensionScalar): + """ + Concrete class for Uuid extension scalar. + """ + + def as_py(self): + return None if self.value is None else UUID(bytes=self.value.as_py()) + + +cdef class FixedShapeTensorScalar(ExtensionScalar): + """ + Concrete class for fixed shape tensor extension scalar. + """ + + def to_numpy(self): + """ + Convert fixed shape tensor scalar to a numpy.ndarray. + + The resulting ndarray's shape matches the permuted shape of the + fixed shape tensor scalar. + The conversion is zero-copy. + + Returns + ------- + numpy.ndarray + """ + return self.to_tensor().to_numpy() + + def to_tensor(self): + """ + Convert fixed shape tensor extension scalar to a pyarrow.Tensor, using shape + and strides derived from corresponding FixedShapeTensorType. + + The conversion is zero-copy. + + Returns + ------- + pyarrow.Tensor + Tensor represented stored in FixedShapeTensorScalar. + """ + cdef: + CFixedShapeTensorType* c_type = static_pointer_cast[CFixedShapeTensorType, CDataType]( + self.wrapped.get().type).get() + shared_ptr[CExtensionScalar] scalar = static_pointer_cast[CExtensionScalar, CScalar](self.wrapped) + shared_ptr[CTensor] ctensor + + with nogil: + ctensor = GetResultValue(c_type.MakeTensor(scalar)) + return pyarrow_wrap_tensor(ctensor) + + +cdef class OpaqueScalar(ExtensionScalar): + """ + Concrete class for opaque extension scalar. + """ + + +cdef class Bool8Scalar(ExtensionScalar): + """ + Concrete class for bool8 extension scalar. + """ + + def as_py(self): + """ + Return this scalar as a Python object. 
+ """ + py_val = super().as_py() + return None if py_val is None else py_val != 0 + +cdef dict _scalar_classes = { + _Type_BOOL: BooleanScalar, + _Type_UINT8: UInt8Scalar, + _Type_UINT16: UInt16Scalar, + _Type_UINT32: UInt32Scalar, + _Type_UINT64: UInt64Scalar, + _Type_INT8: Int8Scalar, + _Type_INT16: Int16Scalar, + _Type_INT32: Int32Scalar, + _Type_INT64: Int64Scalar, + _Type_HALF_FLOAT: HalfFloatScalar, + _Type_FLOAT: FloatScalar, + _Type_DOUBLE: DoubleScalar, + _Type_DECIMAL32: Decimal32Scalar, + _Type_DECIMAL64: Decimal64Scalar, + _Type_DECIMAL128: Decimal128Scalar, + _Type_DECIMAL256: Decimal256Scalar, + _Type_DATE32: Date32Scalar, + _Type_DATE64: Date64Scalar, + _Type_TIME32: Time32Scalar, + _Type_TIME64: Time64Scalar, + _Type_TIMESTAMP: TimestampScalar, + _Type_DURATION: DurationScalar, + _Type_BINARY: BinaryScalar, + _Type_LARGE_BINARY: LargeBinaryScalar, + _Type_FIXED_SIZE_BINARY: FixedSizeBinaryScalar, + _Type_BINARY_VIEW: BinaryViewScalar, + _Type_STRING: StringScalar, + _Type_LARGE_STRING: LargeStringScalar, + _Type_STRING_VIEW: StringViewScalar, + _Type_LIST: ListScalar, + _Type_LARGE_LIST: LargeListScalar, + _Type_FIXED_SIZE_LIST: FixedSizeListScalar, + _Type_LIST_VIEW: ListViewScalar, + _Type_LARGE_LIST_VIEW: LargeListViewScalar, + _Type_STRUCT: StructScalar, + _Type_MAP: MapScalar, + _Type_DICTIONARY: DictionaryScalar, + _Type_RUN_END_ENCODED: RunEndEncodedScalar, + _Type_SPARSE_UNION: UnionScalar, + _Type_DENSE_UNION: UnionScalar, + _Type_INTERVAL_MONTH_DAY_NANO: MonthDayNanoIntervalScalar, + _Type_EXTENSION: ExtensionScalar, +} + + +cdef object get_scalar_class_from_type( + const shared_ptr[CDataType]& sp_data_type): + cdef CDataType* data_type = sp_data_type.get() + if data_type == NULL: + raise ValueError('Scalar data type was NULL') + + if data_type.id() == _Type_EXTENSION: + py_ext_data_type = pyarrow_wrap_data_type(sp_data_type) + return py_ext_data_type.__arrow_ext_scalar_class__() + else: + return _scalar_classes[data_type.id()] + + +def scalar(value, type=None, *, from_pandas=None, MemoryPool memory_pool=None): + """ + Create a pyarrow.Scalar instance from a Python object. + + Parameters + ---------- + value : Any + Python object coercible to arrow's type system. + type : pyarrow.DataType + Explicit type to attempt to coerce to, otherwise will be inferred from + the value. + from_pandas : bool, default None + Use pandas's semantics for inferring nulls from values in + ndarray-like data. Defaults to False if not passed explicitly by user, + or True if a pandas object is passed in. + memory_pool : pyarrow.MemoryPool, optional + If not passed, will allocate memory from the currently-set default + memory pool. 
+ + Returns + ------- + scalar : pyarrow.Scalar + + Examples + -------- + >>> import pyarrow as pa + + >>> pa.scalar(42) + + + >>> pa.scalar("string") + + + >>> pa.scalar([1, 2]) + + + >>> pa.scalar([1, 2], type=pa.list_(pa.int16())) + + """ + cdef: + DataType ty + PyConversionOptions options + shared_ptr[CScalar] scalar + shared_ptr[CArray] array + shared_ptr[CChunkedArray] chunked + bint is_pandas_object = False + CMemoryPool* pool + + type = ensure_type(type, allow_none=True) + pool = maybe_unbox_memory_pool(memory_pool) + + extension_type = None + if type is not None and type.id == _Type_EXTENSION: + extension_type = type + type = type.storage_type + + if _is_array_like(value): + value = get_values(value, &is_pandas_object) + + options.size = 1 + + if type is not None: + ty = ensure_type(type) + options.type = ty.sp_type + + if from_pandas is None: + options.from_pandas = is_pandas_object + else: + options.from_pandas = from_pandas + + value = [value] + with nogil: + chunked = GetResultValue(ConvertPySequence(value, None, options, pool)) + + # get the first chunk + assert chunked.get().num_chunks() == 1 + array = chunked.get().chunk(0) + + # retrieve the scalar from the first position + scalar = GetResultValue(array.get().GetScalar(0)) + result = Scalar.wrap(scalar) + + if extension_type is not None: + result = ExtensionScalar.from_storage(extension_type, result) + return result diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/substrait.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/substrait.py new file mode 100644 index 0000000000000000000000000000000000000000..db2c3a96a19556103a9a57d551d53cedb5d53f9e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/substrait.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +try: + from pyarrow._substrait import ( # noqa + BoundExpressions, + get_supported_functions, + run_query, + deserialize_expressions, + serialize_expressions, + deserialize_schema, + serialize_schema, + SubstraitSchema + ) +except ImportError as exc: + raise ImportError( + "The pyarrow installation is not built with support " + f"for 'substrait' ({str(exc)})" + ) from None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/tensor.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/tensor.pxi new file mode 100644 index 0000000000000000000000000000000000000000..3e0c63c18fc98d5d0ba07b058ae52d6c46f544e0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/tensor.pxi @@ -0,0 +1,1311 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
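# --- Editor's note (illustrative sketch, not part of the patch) --------------
# The scalar() factory above special-cases extension types: the value is first
# converted against the storage type, then re-wrapped via
# ExtensionScalar.from_storage().  A minimal sketch of that path, using a
# hypothetical "example.label" extension type defined only for this example:
import pyarrow as pa

class LabelType(pa.ExtensionType):
    """Toy extension type whose storage is a plain string (hypothetical)."""

    def __init__(self):
        super().__init__(pa.string(), "example.label")

    def __arrow_ext_serialize__(self):
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        return cls()

label = LabelType()

# Wrap an existing storage value directly ...
s1 = pa.ExtensionScalar.from_storage(label, "hello")

# ... or let pa.scalar() coerce the value to the storage type and re-wrap it.
s2 = pa.scalar("hello", type=label)

assert s1.type == label and s2.as_py() == "hello"
# ------------------------------------------------------------------------------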
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Avoid name clash with `pa.struct` function +import struct as _struct + + +cdef class Tensor(_Weakrefable): + """ + A n-dimensional array a.k.a Tensor. + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> x = np.array([[2, 2, 4], [4, 5, 100]], np.int32) + >>> pa.Tensor.from_numpy(x, dim_names=["dim1","dim2"]) + + type: int32 + shape: (2, 3) + strides: (12, 4) + """ + + def __init__(self): + raise TypeError("Do not call Tensor's constructor directly, use one " + "of the `pyarrow.Tensor.from_*` functions instead.") + + cdef void init(self, const shared_ptr[CTensor]& sp_tensor): + self.sp_tensor = sp_tensor + self.tp = sp_tensor.get() + self.type = pyarrow_wrap_data_type(self.tp.type()) + self._ssize_t_shape = self._make_shape_or_strides_buffer(self.shape) + self._ssize_t_strides = self._make_shape_or_strides_buffer(self.strides) + + def _make_shape_or_strides_buffer(self, values): + """ + Make a bytes object holding an array of `values` cast to `Py_ssize_t`. + """ + return _struct.pack(f"{len(values)}n", *values) + + def __repr__(self): + return """ +type: {0.type} +shape: {0.shape} +strides: {0.strides}""".format(self) + + @staticmethod + def from_numpy(obj, dim_names=None): + """ + Create a Tensor from a numpy array. + + Parameters + ---------- + obj : numpy.ndarray + The source numpy array + dim_names : list, optional + Names of each dimension of the Tensor. + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> x = np.array([[2, 2, 4], [4, 5, 100]], np.int32) + >>> pa.Tensor.from_numpy(x, dim_names=["dim1","dim2"]) + + type: int32 + shape: (2, 3) + strides: (12, 4) + """ + cdef: + vector[c_string] c_dim_names + shared_ptr[CTensor] ctensor + + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + check_status(NdarrayToTensor(c_default_memory_pool(), obj, + c_dim_names, &ctensor)) + return pyarrow_wrap_tensor(ctensor) + + def to_numpy(self): + """ + Convert arrow::Tensor to numpy.ndarray with zero copy + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> x = np.array([[2, 2, 4], [4, 5, 100]], np.int32) + >>> tensor = pa.Tensor.from_numpy(x, dim_names=["dim1","dim2"]) + >>> tensor.to_numpy() + array([[ 2, 2, 4], + [ 4, 5, 100]], dtype=int32) + """ + if np is None: + raise ImportError( + "Cannot return a numpy.ndarray if NumPy is not present") + cdef PyObject* out + + check_status(TensorToNdarray(self.sp_tensor, self, &out)) + return PyObject_to_object(out) + + def equals(self, Tensor other): + """ + Return true if the tensors contains exactly equal data. + + Parameters + ---------- + other : Tensor + The other tensor to compare for equality. 
+ + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> x = np.array([[2, 2, 4], [4, 5, 100]], np.int32) + >>> tensor = pa.Tensor.from_numpy(x, dim_names=["dim1","dim2"]) + >>> y = np.array([[2, 2, 4], [4, 5, 10]], np.int32) + >>> tensor2 = pa.Tensor.from_numpy(y, dim_names=["a","b"]) + >>> tensor.equals(tensor) + True + >>> tensor.equals(tensor2) + False + """ + return self.tp.Equals(deref(other.tp)) + + def __eq__(self, other): + if isinstance(other, Tensor): + return self.equals(other) + else: + return NotImplemented + + def dim_name(self, i): + """ + Returns the name of the i-th tensor dimension. + + Parameters + ---------- + i : int + The physical index of the tensor dimension. + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> x = np.array([[2, 2, 4], [4, 5, 100]], np.int32) + >>> tensor = pa.Tensor.from_numpy(x, dim_names=["dim1","dim2"]) + >>> tensor.dim_name(0) + 'dim1' + >>> tensor.dim_name(1) + 'dim2' + """ + return frombytes(self.tp.dim_name(i)) + + @property + def dim_names(self): + """ + Names of this tensor dimensions. + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> x = np.array([[2, 2, 4], [4, 5, 100]], np.int32) + >>> tensor = pa.Tensor.from_numpy(x, dim_names=["dim1","dim2"]) + >>> tensor.dim_names + ['dim1', 'dim2'] + """ + return [frombytes(x) for x in tuple(self.tp.dim_names())] + + @property + def is_mutable(self): + """ + Is this tensor mutable or immutable. + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> x = np.array([[2, 2, 4], [4, 5, 100]], np.int32) + >>> tensor = pa.Tensor.from_numpy(x, dim_names=["dim1","dim2"]) + >>> tensor.is_mutable + True + """ + return self.tp.is_mutable() + + @property + def is_contiguous(self): + """ + Is this tensor contiguous in memory. + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> x = np.array([[2, 2, 4], [4, 5, 100]], np.int32) + >>> tensor = pa.Tensor.from_numpy(x, dim_names=["dim1","dim2"]) + >>> tensor.is_contiguous + True + """ + return self.tp.is_contiguous() + + @property + def ndim(self): + """ + The dimension (n) of this tensor. + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> x = np.array([[2, 2, 4], [4, 5, 100]], np.int32) + >>> tensor = pa.Tensor.from_numpy(x, dim_names=["dim1","dim2"]) + >>> tensor.ndim + 2 + """ + return self.tp.ndim() + + @property + def size(self): + """ + The size of this tensor. + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> x = np.array([[2, 2, 4], [4, 5, 100]], np.int32) + >>> tensor = pa.Tensor.from_numpy(x, dim_names=["dim1","dim2"]) + >>> tensor.size + 6 + """ + return self.tp.size() + + @property + def shape(self): + """ + The shape of this tensor. + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> x = np.array([[2, 2, 4], [4, 5, 100]], np.int32) + >>> tensor = pa.Tensor.from_numpy(x, dim_names=["dim1","dim2"]) + >>> tensor.shape + (2, 3) + """ + # Cython knows how to convert a vector[T] to a Python list + return tuple(self.tp.shape()) + + @property + def strides(self): + """ + Strides of this tensor. 
+ + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> x = np.array([[2, 2, 4], [4, 5, 100]], np.int32) + >>> tensor = pa.Tensor.from_numpy(x, dim_names=["dim1","dim2"]) + >>> tensor.strides + (12, 4) + """ + return tuple(self.tp.strides()) + + def __getbuffer__(self, cp.Py_buffer* buffer, int flags): + buffer.buf = self.tp.data().get().data() + pep3118_format = self.type.pep3118_format + if pep3118_format is None: + raise NotImplementedError("type %s not supported for buffer " + "protocol" % (self.type,)) + buffer.format = pep3118_format + buffer.itemsize = self.type.bit_width // 8 + buffer.internal = NULL + buffer.len = self.tp.size() * buffer.itemsize + buffer.ndim = self.tp.ndim() + buffer.obj = self + if self.tp.is_mutable(): + buffer.readonly = 0 + else: + buffer.readonly = 1 + buffer.shape = cp.PyBytes_AsString(self._ssize_t_shape) + buffer.strides = cp.PyBytes_AsString(self._ssize_t_strides) + buffer.suboffsets = NULL + + +ctypedef CSparseCOOIndex* _CSparseCOOIndexPtr + + +cdef class SparseCOOTensor(_Weakrefable): + """ + A sparse COO tensor. + """ + + def __init__(self): + raise TypeError("Do not call SparseCOOTensor's constructor directly, " + "use one of the `pyarrow.SparseCOOTensor.from_*` " + "functions instead.") + + cdef void init(self, const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor): + self.sp_sparse_tensor = sp_sparse_tensor + self.stp = sp_sparse_tensor.get() + self.type = pyarrow_wrap_data_type(self.stp.type()) + + def __repr__(self): + return """ +type: {0.type} +shape: {0.shape}""".format(self) + + @classmethod + def from_dense_numpy(cls, obj, dim_names=None): + """ + Convert numpy.ndarray to arrow::SparseCOOTensor + + Parameters + ---------- + obj : numpy.ndarray + Data used to populate the rows. + dim_names : list[str], optional + Names of the dimensions. + + Returns + ------- + pyarrow.SparseCOOTensor + """ + return cls.from_tensor(Tensor.from_numpy(obj, dim_names=dim_names)) + + @staticmethod + def from_numpy(data, coords, shape, dim_names=None): + """ + Create arrow::SparseCOOTensor from numpy.ndarrays + + Parameters + ---------- + data : numpy.ndarray + Data used to populate the rows. + coords : numpy.ndarray + Coordinates of the data. + shape : tuple + Shape of the tensor. + dim_names : list, optional + Names of the dimensions. + """ + cdef shared_ptr[CSparseCOOTensor] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + # Enforce precondition for SparseCOOTensor indices + coords = np.require(coords, dtype='i8', requirements='C') + if coords.ndim != 2: + raise ValueError("Expected 2-dimensional array for " + "SparseCOOTensor indices") + + check_status(NdarraysToSparseCOOTensor(c_default_memory_pool(), + data, coords, c_shape, + c_dim_names, &csparse_tensor)) + return pyarrow_wrap_sparse_coo_tensor(csparse_tensor) + + @staticmethod + def from_scipy(obj, dim_names=None): + """ + Convert scipy.sparse.coo_matrix to arrow::SparseCOOTensor + + Parameters + ---------- + obj : scipy.sparse.csr_matrix + The scipy matrix that should be converted. + dim_names : list, optional + Names of the dimensions. 
+ """ + import scipy.sparse + if not isinstance(obj, scipy.sparse.coo_matrix): + raise TypeError( + "Expected scipy.sparse.coo_matrix, got {}".format(type(obj))) + + cdef shared_ptr[CSparseCOOTensor] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in obj.shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + row = obj.row + col = obj.col + + # When SciPy's coo_matrix has canonical format, its indices matrix is + # sorted in column-major order. As Arrow's SparseCOOIndex is sorted + # in row-major order if it is canonical, we must sort indices matrix + # into row-major order to keep its canonicalness, here. + if obj.has_canonical_format: + order = np.lexsort((col, row)) # sort in row-major order + row = row[order] + col = col[order] + coords = np.vstack([row, col]).T + coords = np.require(coords, dtype='i8', requirements='C') + + check_status(NdarraysToSparseCOOTensor(c_default_memory_pool(), + obj.data, coords, c_shape, + c_dim_names, &csparse_tensor)) + return pyarrow_wrap_sparse_coo_tensor(csparse_tensor) + + @staticmethod + def from_pydata_sparse(obj, dim_names=None): + """ + Convert pydata/sparse.COO to arrow::SparseCOOTensor. + + Parameters + ---------- + obj : pydata.sparse.COO + The sparse multidimensional array that should be converted. + dim_names : list, optional + Names of the dimensions. + """ + import sparse + if not isinstance(obj, sparse.COO): + raise TypeError( + "Expected sparse.COO, got {}".format(type(obj))) + + cdef shared_ptr[CSparseCOOTensor] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in obj.shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + coords = np.require(obj.coords.T, dtype='i8', requirements='C') + + check_status(NdarraysToSparseCOOTensor(c_default_memory_pool(), + obj.data, coords, c_shape, + c_dim_names, &csparse_tensor)) + return pyarrow_wrap_sparse_coo_tensor(csparse_tensor) + + @staticmethod + def from_tensor(obj): + """ + Convert arrow::Tensor to arrow::SparseCOOTensor. + + Parameters + ---------- + obj : Tensor + The tensor that should be converted. + """ + cdef shared_ptr[CSparseCOOTensor] csparse_tensor + cdef shared_ptr[CTensor] ctensor = pyarrow_unwrap_tensor(obj) + + with nogil: + check_status(TensorToSparseCOOTensor(ctensor, &csparse_tensor)) + + return pyarrow_wrap_sparse_coo_tensor(csparse_tensor) + + def to_numpy(self): + """ + Convert arrow::SparseCOOTensor to numpy.ndarrays with zero copy. + """ + if np is None: + raise ImportError( + "Cannot return a numpy.ndarray if NumPy is not present") + cdef PyObject* out_data + cdef PyObject* out_coords + + check_status(SparseCOOTensorToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_coords)) + return PyObject_to_object(out_data), PyObject_to_object(out_coords) + + def to_scipy(self): + """ + Convert arrow::SparseCOOTensor to scipy.sparse.coo_matrix. + """ + from scipy.sparse import coo_matrix + cdef PyObject* out_data + cdef PyObject* out_coords + + check_status(SparseCOOTensorToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_coords)) + data = PyObject_to_object(out_data) + coords = PyObject_to_object(out_coords) + row, col = coords[:, 0], coords[:, 1] + result = coo_matrix((data[:, 0], (row, col)), shape=self.shape) + + # As the description in from_scipy above, we sorted indices matrix + # in row-major order if SciPy's coo_matrix has canonical format. 
+ # So, we must call sum_duplicates() to make the result coo_matrix + # has canonical format. + if self.has_canonical_format: + result.sum_duplicates() + return result + + def to_pydata_sparse(self): + """ + Convert arrow::SparseCOOTensor to pydata/sparse.COO. + """ + from sparse import COO + cdef PyObject* out_data + cdef PyObject* out_coords + + check_status(SparseCOOTensorToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_coords)) + data = PyObject_to_object(out_data) + coords = PyObject_to_object(out_coords) + result = COO(data=data[:, 0], coords=coords.T, shape=self.shape) + return result + + def to_tensor(self): + """ + Convert arrow::SparseCOOTensor to arrow::Tensor. + """ + + cdef shared_ptr[CTensor] ctensor + with nogil: + ctensor = GetResultValue(self.stp.ToTensor()) + + return pyarrow_wrap_tensor(ctensor) + + def equals(self, SparseCOOTensor other): + """ + Return true if sparse tensors contains exactly equal data. + + Parameters + ---------- + other : SparseCOOTensor + The other tensor to compare for equality. + """ + return self.stp.Equals(deref(other.stp)) + + def __eq__(self, other): + if isinstance(other, SparseCOOTensor): + return self.equals(other) + else: + return NotImplemented + + @property + def is_mutable(self): + return self.stp.is_mutable() + + @property + def ndim(self): + return self.stp.ndim() + + @property + def shape(self): + # Cython knows how to convert a vector[T] to a Python list + return tuple(self.stp.shape()) + + @property + def size(self): + return self.stp.size() + + def dim_name(self, i): + """ + Returns the name of the i-th tensor dimension. + + Parameters + ---------- + i : int + The physical index of the tensor dimension. + + Returns + ------- + str + """ + return frombytes(self.stp.dim_name(i)) + + @property + def dim_names(self): + names_tuple = tuple(self.stp.dim_names()) + return tuple(frombytes(x) for x in names_tuple) + + @property + def non_zero_length(self): + return self.stp.non_zero_length() + + @property + def has_canonical_format(self): + cdef: + _CSparseCOOIndexPtr csi + + csi = <_CSparseCOOIndexPtr>(self.stp.sparse_index().get()) + if csi != nullptr: + return csi.is_canonical() + return True + +cdef class SparseCSRMatrix(_Weakrefable): + """ + A sparse CSR matrix. + """ + + def __init__(self): + raise TypeError("Do not call SparseCSRMatrix's constructor directly, " + "use one of the `pyarrow.SparseCSRMatrix.from_*` " + "functions instead.") + + cdef void init(self, const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor): + self.sp_sparse_tensor = sp_sparse_tensor + self.stp = sp_sparse_tensor.get() + self.type = pyarrow_wrap_data_type(self.stp.type()) + + def __repr__(self): + return """ +type: {0.type} +shape: {0.shape}""".format(self) + + @classmethod + def from_dense_numpy(cls, obj, dim_names=None): + """ + Convert numpy.ndarray to arrow::SparseCSRMatrix + + Parameters + ---------- + obj : numpy.ndarray + The dense numpy array that should be converted. + dim_names : list, optional + The names of the dimensions. + + Returns + ------- + pyarrow.SparseCSRMatrix + """ + return cls.from_tensor(Tensor.from_numpy(obj, dim_names=dim_names)) + + @staticmethod + def from_numpy(data, indptr, indices, shape, dim_names=None): + """ + Create arrow::SparseCSRMatrix from numpy.ndarrays. + + Parameters + ---------- + data : numpy.ndarray + Data used to populate the sparse matrix. + indptr : numpy.ndarray + Range of the rows, + The i-th row spans from `indptr[i]` to `indptr[i+1]` in the data. 
+ indices : numpy.ndarray + Column indices of the corresponding non-zero values. + shape : tuple + Shape of the matrix. + dim_names : list, optional + Names of the dimensions. + """ + cdef shared_ptr[CSparseCSRMatrix] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + # Enforce precondition for SparseCSRMatrix indices + indptr = np.require(indptr, dtype='i8') + indices = np.require(indices, dtype='i8') + if indptr.ndim != 1: + raise ValueError("Expected 1-dimensional array for " + "SparseCSRMatrix indptr") + if indices.ndim != 1: + raise ValueError("Expected 1-dimensional array for " + "SparseCSRMatrix indices") + + check_status(NdarraysToSparseCSRMatrix(c_default_memory_pool(), + data, indptr, indices, c_shape, + c_dim_names, &csparse_tensor)) + return pyarrow_wrap_sparse_csr_matrix(csparse_tensor) + + @staticmethod + def from_scipy(obj, dim_names=None): + """ + Convert scipy.sparse.csr_matrix to arrow::SparseCSRMatrix. + + Parameters + ---------- + obj : scipy.sparse.csr_matrix + The scipy matrix that should be converted. + dim_names : list, optional + Names of the dimensions. + """ + import scipy.sparse + if not isinstance(obj, scipy.sparse.csr_matrix): + raise TypeError( + "Expected scipy.sparse.csr_matrix, got {}".format(type(obj))) + + cdef shared_ptr[CSparseCSRMatrix] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in obj.shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + # Enforce precondition for CSparseCSRMatrix indices + indptr = np.require(obj.indptr, dtype='i8') + indices = np.require(obj.indices, dtype='i8') + + check_status(NdarraysToSparseCSRMatrix(c_default_memory_pool(), + obj.data, indptr, indices, + c_shape, c_dim_names, + &csparse_tensor)) + return pyarrow_wrap_sparse_csr_matrix(csparse_tensor) + + @staticmethod + def from_tensor(obj): + """ + Convert arrow::Tensor to arrow::SparseCSRMatrix. + + Parameters + ---------- + obj : Tensor + The dense tensor that should be converted. + """ + cdef shared_ptr[CSparseCSRMatrix] csparse_tensor + cdef shared_ptr[CTensor] ctensor = pyarrow_unwrap_tensor(obj) + + with nogil: + check_status(TensorToSparseCSRMatrix(ctensor, &csparse_tensor)) + + return pyarrow_wrap_sparse_csr_matrix(csparse_tensor) + + def to_numpy(self): + """ + Convert arrow::SparseCSRMatrix to numpy.ndarrays with zero copy. + """ + if np is None: + raise ImportError( + "Cannot return a numpy.ndarray if NumPy is not present") + cdef PyObject* out_data + cdef PyObject* out_indptr + cdef PyObject* out_indices + + check_status(SparseCSRMatrixToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_indptr, + &out_indices)) + return (PyObject_to_object(out_data), PyObject_to_object(out_indptr), + PyObject_to_object(out_indices)) + + def to_scipy(self): + """ + Convert arrow::SparseCSRMatrix to scipy.sparse.csr_matrix. 
+ """ + from scipy.sparse import csr_matrix + cdef PyObject* out_data + cdef PyObject* out_indptr + cdef PyObject* out_indices + + check_status(SparseCSRMatrixToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_indptr, + &out_indices)) + + data = PyObject_to_object(out_data) + indptr = PyObject_to_object(out_indptr) + indices = PyObject_to_object(out_indices) + result = csr_matrix((data[:, 0], indices, indptr), shape=self.shape) + return result + + def to_tensor(self): + """ + Convert arrow::SparseCSRMatrix to arrow::Tensor. + """ + cdef shared_ptr[CTensor] ctensor + with nogil: + ctensor = GetResultValue(self.stp.ToTensor()) + + return pyarrow_wrap_tensor(ctensor) + + def equals(self, SparseCSRMatrix other): + """ + Return true if sparse tensors contains exactly equal data. + + Parameters + ---------- + other : SparseCSRMatrix + The other tensor to compare for equality. + """ + return self.stp.Equals(deref(other.stp)) + + def __eq__(self, other): + if isinstance(other, SparseCSRMatrix): + return self.equals(other) + else: + return NotImplemented + + @property + def is_mutable(self): + return self.stp.is_mutable() + + @property + def ndim(self): + return self.stp.ndim() + + @property + def shape(self): + # Cython knows how to convert a vector[T] to a Python list + return tuple(self.stp.shape()) + + @property + def size(self): + return self.stp.size() + + def dim_name(self, i): + """ + Returns the name of the i-th tensor dimension. + + Parameters + ---------- + i : int + The physical index of the tensor dimension. + + Returns + ------- + str + """ + return frombytes(self.stp.dim_name(i)) + + @property + def dim_names(self): + names_tuple = tuple(self.stp.dim_names()) + return tuple(frombytes(x) for x in names_tuple) + + @property + def non_zero_length(self): + return self.stp.non_zero_length() + +cdef class SparseCSCMatrix(_Weakrefable): + """ + A sparse CSC matrix. + """ + + def __init__(self): + raise TypeError("Do not call SparseCSCMatrix's constructor directly, " + "use one of the `pyarrow.SparseCSCMatrix.from_*` " + "functions instead.") + + cdef void init(self, const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor): + self.sp_sparse_tensor = sp_sparse_tensor + self.stp = sp_sparse_tensor.get() + self.type = pyarrow_wrap_data_type(self.stp.type()) + + def __repr__(self): + return """ +type: {0.type} +shape: {0.shape}""".format(self) + + @classmethod + def from_dense_numpy(cls, obj, dim_names=None): + """ + Convert numpy.ndarray to arrow::SparseCSCMatrix + + Parameters + ---------- + obj : numpy.ndarray + Data used to populate the rows. + dim_names : list[str], optional + Names of the dimensions. + + Returns + ------- + pyarrow.SparseCSCMatrix + """ + return cls.from_tensor(Tensor.from_numpy(obj, dim_names=dim_names)) + + @staticmethod + def from_numpy(data, indptr, indices, shape, dim_names=None): + """ + Create arrow::SparseCSCMatrix from numpy.ndarrays + + Parameters + ---------- + data : numpy.ndarray + Data used to populate the sparse matrix. + indptr : numpy.ndarray + Range of the rows, + The i-th row spans from `indptr[i]` to `indptr[i+1]` in the data. + indices : numpy.ndarray + Column indices of the corresponding non-zero values. + shape : tuple + Shape of the matrix. + dim_names : list, optional + Names of the dimensions. 
+ """ + cdef shared_ptr[CSparseCSCMatrix] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + # Enforce precondition for SparseCSCMatrix indices + indptr = np.require(indptr, dtype='i8') + indices = np.require(indices, dtype='i8') + if indptr.ndim != 1: + raise ValueError("Expected 1-dimensional array for " + "SparseCSCMatrix indptr") + if indices.ndim != 1: + raise ValueError("Expected 1-dimensional array for " + "SparseCSCMatrix indices") + + check_status(NdarraysToSparseCSCMatrix(c_default_memory_pool(), + data, indptr, indices, c_shape, + c_dim_names, &csparse_tensor)) + return pyarrow_wrap_sparse_csc_matrix(csparse_tensor) + + @staticmethod + def from_scipy(obj, dim_names=None): + """ + Convert scipy.sparse.csc_matrix to arrow::SparseCSCMatrix + + Parameters + ---------- + obj : scipy.sparse.csc_matrix + The scipy matrix that should be converted. + dim_names : list, optional + Names of the dimensions. + """ + import scipy.sparse + if not isinstance(obj, scipy.sparse.csc_matrix): + raise TypeError( + "Expected scipy.sparse.csc_matrix, got {}".format(type(obj))) + + cdef shared_ptr[CSparseCSCMatrix] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in obj.shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + # Enforce precondition for CSparseCSCMatrix indices + indptr = np.require(obj.indptr, dtype='i8') + indices = np.require(obj.indices, dtype='i8') + + check_status(NdarraysToSparseCSCMatrix(c_default_memory_pool(), + obj.data, indptr, indices, + c_shape, c_dim_names, + &csparse_tensor)) + return pyarrow_wrap_sparse_csc_matrix(csparse_tensor) + + @staticmethod + def from_tensor(obj): + """ + Convert arrow::Tensor to arrow::SparseCSCMatrix + + Parameters + ---------- + obj : Tensor + The dense tensor that should be converted. 
+ """ + cdef shared_ptr[CSparseCSCMatrix] csparse_tensor + cdef shared_ptr[CTensor] ctensor = pyarrow_unwrap_tensor(obj) + + with nogil: + check_status(TensorToSparseCSCMatrix(ctensor, &csparse_tensor)) + + return pyarrow_wrap_sparse_csc_matrix(csparse_tensor) + + def to_numpy(self): + """ + Convert arrow::SparseCSCMatrix to numpy.ndarrays with zero copy + """ + if np is None: + raise ImportError( + "Cannot return a numpy.ndarray if NumPy is not present") + cdef PyObject* out_data + cdef PyObject* out_indptr + cdef PyObject* out_indices + + check_status(SparseCSCMatrixToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_indptr, + &out_indices)) + return (PyObject_to_object(out_data), PyObject_to_object(out_indptr), + PyObject_to_object(out_indices)) + + def to_scipy(self): + """ + Convert arrow::SparseCSCMatrix to scipy.sparse.csc_matrix + """ + from scipy.sparse import csc_matrix + cdef PyObject* out_data + cdef PyObject* out_indptr + cdef PyObject* out_indices + + check_status(SparseCSCMatrixToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_indptr, + &out_indices)) + + data = PyObject_to_object(out_data) + indptr = PyObject_to_object(out_indptr) + indices = PyObject_to_object(out_indices) + result = csc_matrix((data[:, 0], indices, indptr), shape=self.shape) + return result + + def to_tensor(self): + """ + Convert arrow::SparseCSCMatrix to arrow::Tensor + """ + + cdef shared_ptr[CTensor] ctensor + with nogil: + ctensor = GetResultValue(self.stp.ToTensor()) + + return pyarrow_wrap_tensor(ctensor) + + def equals(self, SparseCSCMatrix other): + """ + Return true if sparse tensors contains exactly equal data + + Parameters + ---------- + other : SparseCSCMatrix + The other tensor to compare for equality. + """ + return self.stp.Equals(deref(other.stp)) + + def __eq__(self, other): + if isinstance(other, SparseCSCMatrix): + return self.equals(other) + else: + return NotImplemented + + @property + def is_mutable(self): + return self.stp.is_mutable() + + @property + def ndim(self): + return self.stp.ndim() + + @property + def shape(self): + # Cython knows how to convert a vector[T] to a Python list + return tuple(self.stp.shape()) + + @property + def size(self): + return self.stp.size() + + def dim_name(self, i): + """ + Returns the name of the i-th tensor dimension. + + Parameters + ---------- + i : int + The physical index of the tensor dimension. + + Returns + ------- + str + """ + return frombytes(self.stp.dim_name(i)) + + @property + def dim_names(self): + names_tuple = tuple(self.stp.dim_names()) + return tuple(frombytes(x) for x in names_tuple) + + @property + def non_zero_length(self): + return self.stp.non_zero_length() + + +cdef class SparseCSFTensor(_Weakrefable): + """ + A sparse CSF tensor. + + CSF is a generalization of compressed sparse row (CSR) index. + + CSF index recursively compresses each dimension of a tensor into a set + of prefix trees. Each path from a root to leaf forms one tensor + non-zero index. CSF is implemented with two arrays of buffers and one + arrays of integers. 
+ """ + + def __init__(self): + raise TypeError("Do not call SparseCSFTensor's constructor directly, " + "use one of the `pyarrow.SparseCSFTensor.from_*` " + "functions instead.") + + cdef void init(self, const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor): + self.sp_sparse_tensor = sp_sparse_tensor + self.stp = sp_sparse_tensor.get() + self.type = pyarrow_wrap_data_type(self.stp.type()) + + def __repr__(self): + return """ +type: {0.type} +shape: {0.shape}""".format(self) + + @classmethod + def from_dense_numpy(cls, obj, dim_names=None): + """ + Convert numpy.ndarray to arrow::SparseCSFTensor + + Parameters + ---------- + obj : numpy.ndarray + Data used to populate the rows. + dim_names : list[str], optional + Names of the dimensions. + + Returns + ------- + pyarrow.SparseCSFTensor + """ + return cls.from_tensor(Tensor.from_numpy(obj, dim_names=dim_names)) + + @staticmethod + def from_numpy(data, indptr, indices, shape, axis_order=None, + dim_names=None): + """ + Create arrow::SparseCSFTensor from numpy.ndarrays + + Parameters + ---------- + data : numpy.ndarray + Data used to populate the sparse tensor. + indptr : numpy.ndarray + The sparsity structure. + Each two consecutive dimensions in a tensor correspond to + a buffer in indices. + A pair of consecutive values at `indptr[dim][i]` + `indptr[dim][i + 1]` signify a range of nodes in + `indices[dim + 1]` who are children of `indices[dim][i]` node. + indices : numpy.ndarray + Stores values of nodes. + Each tensor dimension corresponds to a buffer in indptr. + shape : tuple + Shape of the matrix. + axis_order : list, optional + the sequence in which dimensions were traversed to + produce the prefix tree. + dim_names : list, optional + Names of the dimensions. + """ + cdef shared_ptr[CSparseCSFTensor] csparse_tensor + cdef vector[int64_t] c_axis_order + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in shape: + c_shape.push_back(x) + if not axis_order: + axis_order = np.argsort(shape) + for x in axis_order: + c_axis_order.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + # Enforce preconditions for SparseCSFTensor indices + if not (isinstance(indptr, (list, tuple)) and + isinstance(indices, (list, tuple))): + raise TypeError("Expected list or tuple, got {}, {}" + .format(type(indptr), type(indices))) + if len(indptr) != len(shape) - 1: + raise ValueError("Expected list of {ndim} np.arrays for " + "SparseCSFTensor.indptr".format(ndim=len(shape))) + if len(indices) != len(shape): + raise ValueError("Expected list of {ndim} np.arrays for " + "SparseCSFTensor.indices".format(ndim=len(shape))) + if any([x.ndim != 1 for x in indptr]): + raise ValueError("Expected a list of 1-dimensional arrays for " + "SparseCSFTensor.indptr") + if any([x.ndim != 1 for x in indices]): + raise ValueError("Expected a list of 1-dimensional arrays for " + "SparseCSFTensor.indices") + indptr = [np.require(arr, dtype='i8') for arr in indptr] + indices = [np.require(arr, dtype='i8') for arr in indices] + + check_status(NdarraysToSparseCSFTensor(c_default_memory_pool(), data, + indptr, indices, c_shape, + c_axis_order, c_dim_names, + &csparse_tensor)) + return pyarrow_wrap_sparse_csf_tensor(csparse_tensor) + + @staticmethod + def from_tensor(obj): + """ + Convert arrow::Tensor to arrow::SparseCSFTensor + + Parameters + ---------- + obj : Tensor + The dense tensor that should be converted. 
+ """ + cdef shared_ptr[CSparseCSFTensor] csparse_tensor + cdef shared_ptr[CTensor] ctensor = pyarrow_unwrap_tensor(obj) + + with nogil: + check_status(TensorToSparseCSFTensor(ctensor, &csparse_tensor)) + + return pyarrow_wrap_sparse_csf_tensor(csparse_tensor) + + def to_numpy(self): + """ + Convert arrow::SparseCSFTensor to numpy.ndarrays with zero copy + """ + if np is None: + raise ImportError( + "Cannot return a numpy.ndarray if NumPy is not present") + cdef PyObject* out_data + cdef PyObject* out_indptr + cdef PyObject* out_indices + + check_status(SparseCSFTensorToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_indptr, + &out_indices)) + return (PyObject_to_object(out_data), PyObject_to_object(out_indptr), + PyObject_to_object(out_indices)) + + def to_tensor(self): + """ + Convert arrow::SparseCSFTensor to arrow::Tensor + """ + + cdef shared_ptr[CTensor] ctensor + with nogil: + ctensor = GetResultValue(self.stp.ToTensor()) + + return pyarrow_wrap_tensor(ctensor) + + def equals(self, SparseCSFTensor other): + """ + Return true if sparse tensors contains exactly equal data + + Parameters + ---------- + other : SparseCSFTensor + The other tensor to compare for equality. + """ + return self.stp.Equals(deref(other.stp)) + + def __eq__(self, other): + if isinstance(other, SparseCSFTensor): + return self.equals(other) + else: + return NotImplemented + + @property + def is_mutable(self): + return self.stp.is_mutable() + + @property + def ndim(self): + return self.stp.ndim() + + @property + def shape(self): + # Cython knows how to convert a vector[T] to a Python list + return tuple(self.stp.shape()) + + @property + def size(self): + return self.stp.size() + + def dim_name(self, i): + """ + Returns the name of the i-th tensor dimension. + + Parameters + ---------- + i : int + The physical index of the tensor dimension. + + Returns + ------- + str + """ + return frombytes(self.stp.dim_name(i)) + + @property + def dim_names(self): + names_tuple = tuple(self.stp.dim_names()) + return tuple(frombytes(x) for x in names_tuple) + + @property + def non_zero_length(self): + return self.stp.non_zero_length() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/types.pxi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/types.pxi new file mode 100644 index 0000000000000000000000000000000000000000..3caf068a4c9b1dc8e94e9ce8019fd3d5e5c66c7e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/types.pxi @@ -0,0 +1,6202 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from cpython.pycapsule cimport ( + PyCapsule_CheckExact, + PyCapsule_GetPointer, + PyCapsule_GetName, + PyCapsule_New, + PyCapsule_IsValid +) + +import atexit +from collections.abc import Mapping +import pickle +import re +import sys +import warnings +from cython import sizeof + +# These are imprecise because the type (in pandas 0.x) depends on the presence +# of nulls +cdef dict _pandas_type_map = {} + + +def _get_pandas_type_map(): + global _pandas_type_map + if not _pandas_type_map: + _pandas_type_map.update({ + _Type_NA: np.object_, # NaNs + _Type_BOOL: np.bool_, + _Type_INT8: np.int8, + _Type_INT16: np.int16, + _Type_INT32: np.int32, + _Type_INT64: np.int64, + _Type_UINT8: np.uint8, + _Type_UINT16: np.uint16, + _Type_UINT32: np.uint32, + _Type_UINT64: np.uint64, + _Type_HALF_FLOAT: np.float16, + _Type_FLOAT: np.float32, + _Type_DOUBLE: np.float64, + # Pandas does not support [D]ay, so default to [ms] for date32 + _Type_DATE32: np.dtype('datetime64[ms]'), + _Type_DATE64: np.dtype('datetime64[ms]'), + _Type_TIMESTAMP: { + 's': np.dtype('datetime64[s]'), + 'ms': np.dtype('datetime64[ms]'), + 'us': np.dtype('datetime64[us]'), + 'ns': np.dtype('datetime64[ns]'), + }, + _Type_DURATION: { + 's': np.dtype('timedelta64[s]'), + 'ms': np.dtype('timedelta64[ms]'), + 'us': np.dtype('timedelta64[us]'), + 'ns': np.dtype('timedelta64[ns]'), + }, + _Type_BINARY: np.object_, + _Type_FIXED_SIZE_BINARY: np.object_, + _Type_STRING: np.object_, + _Type_LIST: np.object_, + _Type_MAP: np.object_, + _Type_DECIMAL32: np.object_, + _Type_DECIMAL64: np.object_, + _Type_DECIMAL128: np.object_, + _Type_DECIMAL256: np.object_, + }) + return _pandas_type_map + + +cdef dict _pep3118_type_map = { + _Type_INT8: b'b', + _Type_INT16: b'h', + _Type_INT32: b'i', + _Type_INT64: b'q', + _Type_UINT8: b'B', + _Type_UINT16: b'H', + _Type_UINT32: b'I', + _Type_UINT64: b'Q', + _Type_HALF_FLOAT: b'e', + _Type_FLOAT: b'f', + _Type_DOUBLE: b'd', +} + + +cdef bytes _datatype_to_pep3118(CDataType* type): + """ + Construct a PEP 3118 format string describing the given datatype. + None is returned for unsupported types. + """ + try: + char = _pep3118_type_map[type.id()] + except KeyError: + return None + else: + if char in b'bBhHiIqQ': + # Use "standard" int widths, not native + return b'=' + char + else: + return char + + +cdef void* _as_c_pointer(v, allow_null=False) except *: + """ + Convert a Python object to a raw C pointer. + + Used mainly for the C data interface. + Integers are accepted as well as capsule objects with a NULL name. + (the latter for compatibility with raw pointers exported by reticulate) + """ + cdef void* c_ptr + cdef const char* capsule_name + if isinstance(v, int): + c_ptr = v + elif isinstance(v, float): + warnings.warn( + "Passing a pointer value as a float is unsafe and only " + "supported for compatibility with older versions of the R " + "Arrow library", UserWarning, stacklevel=2) + c_ptr = v + elif PyCapsule_CheckExact(v): + # An R external pointer was how the R bindings passed pointer values to + # Python from versions 7 to 15 (inclusive); however, the reticulate 1.35.0 + # update changed the name of the capsule from NULL to "r_extptr". + # Newer versions of the R package pass a Python integer; however, this + # workaround ensures that old versions of the R package continue to work + # with newer versions of pyarrow. 
+ capsule_name = PyCapsule_GetName(v) + if capsule_name == NULL or capsule_name == b"r_extptr": + c_ptr = PyCapsule_GetPointer(v, capsule_name) + else: + capsule_name_str = capsule_name.decode() + raise ValueError( + f"Can't convert PyCapsule with name '{capsule_name_str}' to pointer address" + ) + else: + raise TypeError(f"Expected a pointer value, got {type(v)!r}") + if not allow_null and c_ptr == NULL: + raise ValueError(f"Null pointer (value before cast = {v!r})") + return c_ptr + + +def _is_primitive(Type type): + # This is simply a redirect, the official API is in pyarrow.types. + return is_primitive(type) + + +def _get_pandas_type(arrow_type, coerce_to_ns=False): + cdef Type type_id = arrow_type.id + cdef dict pandas_type_map = _get_pandas_type_map() + if type_id not in pandas_type_map: + return None + if coerce_to_ns: + # ARROW-3789: Coerce date/timestamp types to datetime64[ns] + if type_id == _Type_DURATION: + return np.dtype('timedelta64[ns]') + return np.dtype('datetime64[ns]') + pandas_type = pandas_type_map[type_id] + if isinstance(pandas_type, dict): + unit = getattr(arrow_type, 'unit', None) + pandas_type = pandas_type.get(unit, None) + return pandas_type + + +def _get_pandas_tz_type(arrow_type, coerce_to_ns=False): + from pyarrow.pandas_compat import make_datetimetz + unit = 'ns' if coerce_to_ns else arrow_type.unit + return make_datetimetz(unit, arrow_type.tz) + + +def _to_pandas_dtype(arrow_type, options=None): + coerce_to_ns = (options and options.get('coerce_temporal_nanoseconds', False)) or ( + _pandas_api.is_v1() and arrow_type.id in + [_Type_DATE32, _Type_DATE64, _Type_TIMESTAMP, _Type_DURATION]) + + if getattr(arrow_type, 'tz', None): + dtype = _get_pandas_tz_type(arrow_type, coerce_to_ns) + else: + dtype = _get_pandas_type(arrow_type, coerce_to_ns) + + if not dtype: + raise NotImplementedError(str(arrow_type)) + + return dtype + + +# Workaround for Cython parsing bug +# https://github.com/cython/cython/issues/2143 +ctypedef CFixedWidthType* _CFixedWidthTypePtr + + +cdef class DataType(_Weakrefable): + """ + Base class of all Arrow data types. + + Each data type is an *instance* of this class. + + Examples + -------- + Instance of int64 type: + + >>> import pyarrow as pa + >>> pa.int64() + DataType(int64) + """ + + def __cinit__(self): + pass + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use public " + "functions like pyarrow.int64, pyarrow.list_, etc. " + "instead.".format(self.__class__.__name__)) + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + assert type != nullptr + self.sp_type = type + self.type = type.get() + self.pep3118_format = _datatype_to_pep3118(self.type) + + cpdef Field field(self, i): + """ + Parameters + ---------- + i : int + + Returns + ------- + pyarrow.Field + """ + if not isinstance(i, int): + raise TypeError(f"Expected int index, got type '{type(i)}'") + cdef int index = _normalize_index(i, self.type.num_fields()) + return pyarrow_wrap_field(self.type.field(index)) + + @property + def id(self): + return self.type.id() + + @property + def bit_width(self): + """ + Bit width for fixed width type. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.int64() + DataType(int64) + >>> pa.int64().bit_width + 64 + """ + cdef _CFixedWidthTypePtr ty + ty = dynamic_cast[_CFixedWidthTypePtr](self.type) + if ty == nullptr: + raise ValueError("Non-fixed width type") + return ty.bit_width() + + @property + def byte_width(self): + """ + Byte width for fixed width type. 
+ + Examples + -------- + >>> import pyarrow as pa + >>> pa.int64() + DataType(int64) + >>> pa.int64().byte_width + 8 + """ + cdef _CFixedWidthTypePtr ty + ty = dynamic_cast[_CFixedWidthTypePtr](self.type) + if ty == nullptr: + raise ValueError("Non-fixed width type") + byte_width = ty.byte_width() + if byte_width == 0 and self.bit_width != 0: + raise ValueError("Less than one byte") + return byte_width + + @property + def num_fields(self): + """ + The number of child fields. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.int64() + DataType(int64) + >>> pa.int64().num_fields + 0 + >>> pa.list_(pa.string()) + ListType(list) + >>> pa.list_(pa.string()).num_fields + 1 + >>> struct = pa.struct({'x': pa.int32(), 'y': pa.string()}) + >>> struct.num_fields + 2 + """ + return self.type.num_fields() + + @property + def num_buffers(self): + """ + Number of data buffers required to construct Array type + excluding children. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.int64().num_buffers + 2 + >>> pa.string().num_buffers + 3 + """ + return self.type.layout().buffers.size() + + @property + def has_variadic_buffers(self): + """ + If True, the number of expected buffers is only + lower-bounded by num_buffers. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.int64().has_variadic_buffers + False + >>> pa.string_view().has_variadic_buffers + True + """ + return self.type.layout().variadic_spec.has_value() + + def __str__(self): + return frombytes(self.type.ToString(), safe=True) + + def __hash__(self): + return hash(str(self)) + + def __reduce__(self): + return type_for_alias, (str(self),) + + def __repr__(self): + return '{0.__class__.__name__}({0})'.format(self) + + def __eq__(self, other): + try: + return self.equals(other) + except (TypeError, ValueError): + return NotImplemented + + def equals(self, other, *, check_metadata=False): + """ + Return true if type is equivalent to passed value. + + Parameters + ---------- + other : DataType or string convertible to DataType + check_metadata : bool + Whether nested Field metadata equality should be checked as well. + + Returns + ------- + is_equal : bool + + Examples + -------- + >>> import pyarrow as pa + >>> pa.int64().equals(pa.string()) + False + >>> pa.int64().equals(pa.int64()) + True + """ + cdef: + DataType other_type + c_bool c_check_metadata + + other_type = ensure_type(other) + c_check_metadata = check_metadata + return self.type.Equals(deref(other_type.type), c_check_metadata) + + def to_pandas_dtype(self): + """ + Return the equivalent NumPy / Pandas dtype. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.int64().to_pandas_dtype() + + """ + return _to_pandas_dtype(self) + + def _export_to_c(self, out_ptr): + """ + Export to a C ArrowSchema struct, given its pointer. + + Be careful: if you don't pass the ArrowSchema struct to a consumer, + its memory will leak. This is a low-level function intended for + expert users. + """ + check_status(ExportType(deref(self.type), + _as_c_pointer(out_ptr))) + + @staticmethod + def _import_from_c(in_ptr): + """ + Import DataType from a C ArrowSchema struct, given its pointer. + + This is a low-level function intended for expert users. + """ + result = GetResultValue(ImportType( + _as_c_pointer(in_ptr))) + return pyarrow_wrap_data_type(result) + + def __arrow_c_schema__(self): + """ + Export to a ArrowSchema PyCapsule + + Unlike _export_to_c, this will not leak memory if the capsule is not used. 
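# --- Editor's note (illustrative sketch, not part of the patch) --------------
# The __arrow_c_schema__ / _import_from_c_capsule pair above implements the
# Arrow PyCapsule protocol for data types.  A minimal round-trip sketch; these
# are expert-level entry points, and most interchange goes through consumers
# that accept the capsule protocol directly:
import pyarrow as pa

ty = pa.list_(pa.field("item", pa.int32(), nullable=False))

capsule = ty.__arrow_c_schema__()                    # export as ArrowSchema capsule
roundtripped = pa.DataType._import_from_c_capsule(capsule)

assert roundtripped.equals(ty, check_metadata=True)  # nested field metadata preserved
# ------------------------------------------------------------------------------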
+ """ + cdef ArrowSchema* c_schema + capsule = alloc_c_schema(&c_schema) + + with nogil: + check_status(ExportType(deref(self.type), c_schema)) + + return capsule + + @staticmethod + def _import_from_c_capsule(schema): + """ + Import a DataType from a ArrowSchema PyCapsule + + Parameters + ---------- + schema : PyCapsule + A valid PyCapsule with name 'arrow_schema' containing an + ArrowSchema pointer. + """ + cdef: + ArrowSchema* c_schema + shared_ptr[CDataType] c_type + + if not PyCapsule_IsValid(schema, 'arrow_schema'): + raise TypeError( + "Not an ArrowSchema object" + ) + c_schema = PyCapsule_GetPointer(schema, 'arrow_schema') + + with nogil: + c_type = GetResultValue(ImportType(c_schema)) + + return pyarrow_wrap_data_type(c_type) + + +cdef class DictionaryMemo(_Weakrefable): + """ + Tracking container for dictionary-encoded fields. + """ + + def __cinit__(self): + self.sp_memo.reset(new CDictionaryMemo()) + self.memo = self.sp_memo.get() + + +cdef class DictionaryType(DataType): + """ + Concrete class for dictionary data types. + + Examples + -------- + Create an instance of dictionary type: + + >>> import pyarrow as pa + >>> pa.dictionary(pa.int64(), pa.utf8()) + DictionaryType(dictionary) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.dict_type = type.get() + + def __reduce__(self): + return dictionary, (self.index_type, self.value_type, self.ordered) + + @property + def ordered(self): + """ + Whether the dictionary is ordered, i.e. whether the ordering of values + in the dictionary is important. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.dictionary(pa.int64(), pa.utf8()).ordered + False + """ + return self.dict_type.ordered() + + @property + def index_type(self): + """ + The data type of dictionary indices (a signed integer type). + + Examples + -------- + >>> import pyarrow as pa + >>> pa.dictionary(pa.int16(), pa.utf8()).index_type + DataType(int16) + """ + return pyarrow_wrap_data_type(self.dict_type.index_type()) + + @property + def value_type(self): + """ + The dictionary value type. + + The dictionary values are found in an instance of DictionaryArray. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.dictionary(pa.int16(), pa.utf8()).value_type + DataType(string) + """ + return pyarrow_wrap_data_type(self.dict_type.value_type()) + + +cdef class ListType(DataType): + """ + Concrete class for list data types. + + Examples + -------- + Create an instance of ListType: + + >>> import pyarrow as pa + >>> pa.list_(pa.string()) + ListType(list) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.list_type = type.get() + + def __reduce__(self): + return list_, (self.value_field,) + + @property + def value_field(self): + """ + The field for list values. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.list_(pa.string()).value_field + pyarrow.Field + """ + return pyarrow_wrap_field(self.list_type.value_field()) + + @property + def value_type(self): + """ + The data type of list values. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.list_(pa.string()).value_type + DataType(string) + """ + return pyarrow_wrap_data_type(self.list_type.value_type()) + + +cdef class LargeListType(DataType): + """ + Concrete class for large list data types + (like ListType, but with 64-bit offsets). 
+ + Examples + -------- + Create an instance of LargeListType: + + >>> import pyarrow as pa + >>> pa.large_list(pa.string()) + LargeListType(large_list) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.list_type = type.get() + + def __reduce__(self): + return large_list, (self.value_field,) + + @property + def value_field(self): + return pyarrow_wrap_field(self.list_type.value_field()) + + @property + def value_type(self): + """ + The data type of large list values. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.large_list(pa.string()).value_type + DataType(string) + """ + return pyarrow_wrap_data_type(self.list_type.value_type()) + + +cdef class ListViewType(DataType): + """ + Concrete class for list view data types. + + Examples + -------- + Create an instance of ListViewType: + + >>> import pyarrow as pa + >>> pa.list_view(pa.string()) + ListViewType(list_view) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.list_view_type = type.get() + + def __reduce__(self): + return list_view, (self.value_field,) + + @property + def value_field(self): + """ + The field for list view values. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.list_view(pa.string()).value_field + pyarrow.Field + """ + return pyarrow_wrap_field(self.list_view_type.value_field()) + + @property + def value_type(self): + """ + The data type of list view values. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.list_view(pa.string()).value_type + DataType(string) + """ + return pyarrow_wrap_data_type(self.list_view_type.value_type()) + + +cdef class LargeListViewType(DataType): + """ + Concrete class for large list view data types + (like ListViewType, but with 64-bit offsets). + + Examples + -------- + Create an instance of LargeListViewType: + + >>> import pyarrow as pa + >>> pa.large_list_view(pa.string()) + LargeListViewType(large_list_view) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.list_view_type = type.get() + + def __reduce__(self): + return large_list_view, (self.value_field,) + + @property + def value_field(self): + """ + The field for large list view values. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.large_list_view(pa.string()).value_field + pyarrow.Field + """ + return pyarrow_wrap_field(self.list_view_type.value_field()) + + @property + def value_type(self): + """ + The data type of large list view values. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.large_list_view(pa.string()).value_type + DataType(string) + """ + return pyarrow_wrap_data_type(self.list_view_type.value_type()) + + +cdef class MapType(DataType): + """ + Concrete class for map data types. + + Examples + -------- + Create an instance of MapType: + + >>> import pyarrow as pa + >>> pa.map_(pa.string(), pa.int32()) + MapType(map) + >>> pa.map_(pa.string(), pa.int32(), keys_sorted=True) + MapType(map) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.map_type = type.get() + + def __reduce__(self): + return map_, (self.key_field, self.item_field) + + @property + def key_field(self): + """ + The field for keys in the map entries. 
+ + Examples + -------- + >>> import pyarrow as pa + >>> pa.map_(pa.string(), pa.int32()).key_field + pyarrow.Field + """ + return pyarrow_wrap_field(self.map_type.key_field()) + + @property + def key_type(self): + """ + The data type of keys in the map entries. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.map_(pa.string(), pa.int32()).key_type + DataType(string) + """ + return pyarrow_wrap_data_type(self.map_type.key_type()) + + @property + def item_field(self): + """ + The field for items in the map entries. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.map_(pa.string(), pa.int32()).item_field + pyarrow.Field + """ + return pyarrow_wrap_field(self.map_type.item_field()) + + @property + def item_type(self): + """ + The data type of items in the map entries. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.map_(pa.string(), pa.int32()).item_type + DataType(int32) + """ + return pyarrow_wrap_data_type(self.map_type.item_type()) + + @property + def keys_sorted(self): + """ + Should the entries be sorted according to keys. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.map_(pa.string(), pa.int32(), keys_sorted=True).keys_sorted + True + """ + return self.map_type.keys_sorted() + + +cdef class FixedSizeListType(DataType): + """ + Concrete class for fixed size list data types. + + Examples + -------- + Create an instance of FixedSizeListType: + + >>> import pyarrow as pa + >>> pa.list_(pa.int32(), 2) + FixedSizeListType(fixed_size_list[2]) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.list_type = type.get() + + def __reduce__(self): + return list_, (self.value_type, self.list_size) + + @property + def value_field(self): + """ + The field for list values. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.list_(pa.int32(), 2).value_field + pyarrow.Field + """ + return pyarrow_wrap_field(self.list_type.value_field()) + + @property + def value_type(self): + """ + The data type of large list values. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.list_(pa.int32(), 2).value_type + DataType(int32) + """ + return pyarrow_wrap_data_type(self.list_type.value_type()) + + @property + def list_size(self): + """ + The size of the fixed size lists. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.list_(pa.int32(), 2).list_size + 2 + """ + return self.list_type.list_size() + + +cdef class StructType(DataType): + """ + Concrete class for struct data types. + + ``StructType`` supports direct indexing using ``[...]`` (implemented via + ``__getitem__``) to access its fields. + It will return the struct field with the given index or name. + + Examples + -------- + >>> import pyarrow as pa + + Accessing fields using direct indexing: + + >>> struct_type = pa.struct({'x': pa.int32(), 'y': pa.string()}) + >>> struct_type[0] + pyarrow.Field + >>> struct_type['y'] + pyarrow.Field + + Accessing fields using ``field()``: + + >>> struct_type.field(1) + pyarrow.Field + >>> struct_type.field('x') + pyarrow.Field + + # Creating a schema from the struct type's fields: + >>> pa.schema(list(struct_type)) + x: int32 + y: string + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.struct_type = type.get() + + cdef Field field_by_name(self, name): + """ + Return a child field by its name. + + Parameters + ---------- + name : str + The name of the field to look up. 
+ + Returns + ------- + field : Field + The child field with the given name. + + Raises + ------ + KeyError + If the name isn't found, or if several fields have the given + name. + """ + cdef vector[shared_ptr[CField]] fields + + fields = self.struct_type.GetAllFieldsByName(tobytes(name)) + if fields.size() == 0: + raise KeyError(name) + elif fields.size() > 1: + warnings.warn("Struct field name corresponds to more " + "than one field", UserWarning) + raise KeyError(name) + else: + return pyarrow_wrap_field(fields[0]) + + def get_field_index(self, name): + """ + Return index of the unique field with the given name. + + Parameters + ---------- + name : str + The name of the field to look up. + + Returns + ------- + index : int + The index of the field with the given name; -1 if the + name isn't found or there are several fields with the given + name. + + Examples + -------- + >>> import pyarrow as pa + >>> struct_type = pa.struct({'x': pa.int32(), 'y': pa.string()}) + + Index of the field with a name 'y': + + >>> struct_type.get_field_index('y') + 1 + + Index of the field that does not exist: + + >>> struct_type.get_field_index('z') + -1 + """ + return self.struct_type.GetFieldIndex(tobytes(name)) + + cpdef Field field(self, i): + """ + Select a field by its column name or numeric index. + + Parameters + ---------- + i : int or str + + Returns + ------- + pyarrow.Field + + Examples + -------- + + >>> import pyarrow as pa + >>> struct_type = pa.struct({'x': pa.int32(), 'y': pa.string()}) + + Select the second field: + + >>> struct_type.field(1) + pyarrow.Field + + Select the field named 'x': + + >>> struct_type.field('x') + pyarrow.Field + """ + if isinstance(i, (bytes, str)): + return self.field_by_name(i) + elif isinstance(i, int): + return DataType.field(self, i) + else: + raise TypeError('Expected integer or string index') + + def get_all_field_indices(self, name): + """ + Return sorted list of indices for the fields with the given name. + + Parameters + ---------- + name : str + The name of the field to look up. + + Returns + ------- + indices : List[int] + + Examples + -------- + >>> import pyarrow as pa + >>> struct_type = pa.struct({'x': pa.int32(), 'y': pa.string()}) + >>> struct_type.get_all_field_indices('x') + [0] + """ + return self.struct_type.GetAllFieldIndices(tobytes(name)) + + def __len__(self): + """ + Like num_fields(). + """ + return self.type.num_fields() + + def __iter__(self): + """ + Iterate over struct fields, in order. + """ + for i in range(len(self)): + yield self[i] + + def __getitem__(self, i): + """ + Return the struct field with the given index or name. + + Alias of ``field``. + """ + return self.field(i) + + def __reduce__(self): + return struct, (list(self),) + + @property + def names(self): + """ + Lists the field names. + + Examples + -------- + >>> import pyarrow as pa + >>> struct_type = pa.struct([('a', pa.int64()), ('b', pa.float64()), ('c', pa.string())]) + >>> struct_type.names + ['a', 'b', 'c'] + """ + return [f.name for f in self] + + @property + def fields(self): + """ + Lists all fields within the StructType. + + Examples + -------- + >>> import pyarrow as pa + >>> struct_type = pa.struct([('a', pa.int64()), ('b', pa.float64()), ('c', pa.string())]) + >>> struct_type.fields + [pyarrow.Field, pyarrow.Field, pyarrow.Field] + """ + return list(self) + +cdef class UnionType(DataType): + """ + Base class for union data types. 
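+
+ A union is either *sparse* or *dense*; construct one with
+ ``pa.sparse_union``, ``pa.dense_union``, or ``pa.union`` with an
+ explicit ``mode``, as shown below.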
+ + Examples + -------- + Create an instance of a dense UnionType using ``pa.union``: + + >>> import pyarrow as pa + >>> pa.union([pa.field('a', pa.binary(10)), pa.field('b', pa.string())], + ... mode=pa.lib.UnionMode_DENSE), + (DenseUnionType(dense_union),) + + Create an instance of a dense UnionType using ``pa.dense_union``: + + >>> pa.dense_union([pa.field('a', pa.binary(10)), pa.field('b', pa.string())]) + DenseUnionType(dense_union) + + Create an instance of a sparse UnionType using ``pa.union``: + + >>> pa.union([pa.field('a', pa.binary(10)), pa.field('b', pa.string())], + ... mode=pa.lib.UnionMode_SPARSE), + (SparseUnionType(sparse_union),) + + Create an instance of a sparse UnionType using ``pa.sparse_union``: + + >>> pa.sparse_union([pa.field('a', pa.binary(10)), pa.field('b', pa.string())]) + SparseUnionType(sparse_union) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + + @property + def mode(self): + """ + The mode of the union ("dense" or "sparse"). + + Examples + -------- + >>> import pyarrow as pa + >>> union = pa.sparse_union([pa.field('a', pa.binary(10)), pa.field('b', pa.string())]) + >>> union.mode + 'sparse' + """ + cdef CUnionType* type = self.sp_type.get() + cdef int mode = type.mode() + if mode == _UnionMode_DENSE: + return 'dense' + if mode == _UnionMode_SPARSE: + return 'sparse' + assert 0 + + @property + def type_codes(self): + """ + The type code to indicate each data type in this union. + + Examples + -------- + >>> import pyarrow as pa + >>> union = pa.sparse_union([pa.field('a', pa.binary(10)), pa.field('b', pa.string())]) + >>> union.type_codes + [0, 1] + """ + cdef CUnionType* type = self.sp_type.get() + return type.type_codes() + + def __len__(self): + """ + Like num_fields(). + """ + return self.type.num_fields() + + def __iter__(self): + """ + Iterate over union members, in order. + """ + for i in range(len(self)): + yield self[i] + + cpdef Field field(self, i): + """ + Return a child field by its numeric index. + + Parameters + ---------- + i : int + + Returns + ------- + pyarrow.Field + + Examples + -------- + >>> import pyarrow as pa + >>> union = pa.sparse_union([pa.field('a', pa.binary(10)), pa.field('b', pa.string())]) + >>> union[0] + pyarrow.Field + """ + if isinstance(i, int): + return DataType.field(self, i) + else: + raise TypeError('Expected integer') + + def __getitem__(self, i): + """ + Return a child field by its index. + + Alias of ``field``. + """ + return self.field(i) + + def __reduce__(self): + return union, (list(self), self.mode, self.type_codes) + + +cdef class SparseUnionType(UnionType): + """ + Concrete class for sparse union types. + + Examples + -------- + Create an instance of a sparse UnionType using ``pa.union``: + + >>> import pyarrow as pa + >>> pa.union([pa.field('a', pa.binary(10)), pa.field('b', pa.string())], + ... mode=pa.lib.UnionMode_SPARSE), + (SparseUnionType(sparse_union),) + + Create an instance of a sparse UnionType using ``pa.sparse_union``: + + >>> pa.sparse_union([pa.field('a', pa.binary(10)), pa.field('b', pa.string())]) + SparseUnionType(sparse_union) + """ + + +cdef class DenseUnionType(UnionType): + """ + Concrete class for dense union types. + + Examples + -------- + Create an instance of a dense UnionType using ``pa.union``: + + >>> import pyarrow as pa + >>> pa.union([pa.field('a', pa.binary(10)), pa.field('b', pa.string())], + ... 
mode=pa.lib.UnionMode_DENSE), + (DenseUnionType(dense_union),) + + Create an instance of a dense UnionType using ``pa.dense_union``: + + >>> pa.dense_union([pa.field('a', pa.binary(10)), pa.field('b', pa.string())]) + DenseUnionType(dense_union) + """ + + +cdef class TimestampType(DataType): + """ + Concrete class for timestamp data types. + + Examples + -------- + >>> import pyarrow as pa + + Create an instance of timestamp type: + + >>> pa.timestamp('us') + TimestampType(timestamp[us]) + + Create an instance of timestamp type with timezone: + + >>> pa.timestamp('s', tz='UTC') + TimestampType(timestamp[s, tz=UTC]) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.ts_type = type.get() + + @property + def unit(self): + """ + The timestamp unit ('s', 'ms', 'us' or 'ns'). + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.timestamp('us') + >>> t.unit + 'us' + """ + return timeunit_to_string(self.ts_type.unit()) + + @property + def tz(self): + """ + The timestamp time zone, if any, or None. + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.timestamp('s', tz='UTC') + >>> t.tz + 'UTC' + """ + if self.ts_type.timezone().size() > 0: + return frombytes(self.ts_type.timezone()) + else: + return None + + def __reduce__(self): + return timestamp, (self.unit, self.tz) + + +cdef class Time32Type(DataType): + """ + Concrete class for time32 data types. + + Supported time unit resolutions are 's' [second] + and 'ms' [millisecond]. + + Examples + -------- + Create an instance of time32 type: + + >>> import pyarrow as pa + >>> pa.time32('ms') + Time32Type(time32[ms]) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.time_type = type.get() + + @property + def unit(self): + """ + The time unit ('s' or 'ms'). + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.time32('ms') + >>> t.unit + 'ms' + """ + return timeunit_to_string(self.time_type.unit()) + + +cdef class Time64Type(DataType): + """ + Concrete class for time64 data types. + + Supported time unit resolutions are 'us' [microsecond] + and 'ns' [nanosecond]. + + Examples + -------- + Create an instance of time64 type: + + >>> import pyarrow as pa + >>> pa.time64('us') + Time64Type(time64[us]) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.time_type = type.get() + + @property + def unit(self): + """ + The time unit ('us' or 'ns'). + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.time64('us') + >>> t.unit + 'us' + """ + return timeunit_to_string(self.time_type.unit()) + + +cdef class DurationType(DataType): + """ + Concrete class for duration data types. + + Examples + -------- + Create an instance of duration type: + + >>> import pyarrow as pa + >>> pa.duration('s') + DurationType(duration[s]) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.duration_type = type.get() + + @property + def unit(self): + """ + The duration unit ('s', 'ms', 'us' or 'ns'). + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.duration('s') + >>> t.unit + 's' + """ + return timeunit_to_string(self.duration_type.unit()) + + +cdef class FixedSizeBinaryType(DataType): + """ + Concrete class for fixed-size binary data types. 
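+
+ Every value of this type occupies exactly ``byte_width`` bytes.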
+ + Examples + -------- + Create an instance of fixed-size binary type: + + >>> import pyarrow as pa + >>> pa.binary(3) + FixedSizeBinaryType(fixed_size_binary[3]) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.fixed_size_binary_type = ( + type.get()) + + def __reduce__(self): + return binary, (self.byte_width,) + + +cdef class Decimal32Type(FixedSizeBinaryType): + """ + Concrete class for decimal32 data types. + + Examples + -------- + Create an instance of decimal32 type: + + >>> import pyarrow as pa + >>> pa.decimal32(5, 2) + Decimal32Type(decimal32(5, 2)) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + FixedSizeBinaryType.init(self, type) + self.decimal32_type = type.get() + + def __reduce__(self): + return decimal32, (self.precision, self.scale) + + @property + def precision(self): + """ + The decimal precision, in number of decimal digits (an integer). + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.decimal32(5, 2) + >>> t.precision + 5 + """ + return self.decimal32_type.precision() + + @property + def scale(self): + """ + The decimal scale (an integer). + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.decimal32(5, 2) + >>> t.scale + 2 + """ + return self.decimal32_type.scale() + + +cdef class Decimal64Type(FixedSizeBinaryType): + """ + Concrete class for decimal64 data types. + + Examples + -------- + Create an instance of decimal64 type: + + >>> import pyarrow as pa + >>> pa.decimal64(5, 2) + Decimal64Type(decimal64(5, 2)) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + FixedSizeBinaryType.init(self, type) + self.decimal64_type = type.get() + + def __reduce__(self): + return decimal64, (self.precision, self.scale) + + @property + def precision(self): + """ + The decimal precision, in number of decimal digits (an integer). + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.decimal64(5, 2) + >>> t.precision + 5 + """ + return self.decimal64_type.precision() + + @property + def scale(self): + """ + The decimal scale (an integer). + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.decimal64(5, 2) + >>> t.scale + 2 + """ + return self.decimal64_type.scale() + + +cdef class Decimal128Type(FixedSizeBinaryType): + """ + Concrete class for decimal128 data types. + + Examples + -------- + Create an instance of decimal128 type: + + >>> import pyarrow as pa + >>> pa.decimal128(5, 2) + Decimal128Type(decimal128(5, 2)) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + FixedSizeBinaryType.init(self, type) + self.decimal128_type = type.get() + + def __reduce__(self): + return decimal128, (self.precision, self.scale) + + @property + def precision(self): + """ + The decimal precision, in number of decimal digits (an integer). + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.decimal128(5, 2) + >>> t.precision + 5 + """ + return self.decimal128_type.precision() + + @property + def scale(self): + """ + The decimal scale (an integer). + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.decimal128(5, 2) + >>> t.scale + 2 + """ + return self.decimal128_type.scale() + + +cdef class Decimal256Type(FixedSizeBinaryType): + """ + Concrete class for decimal256 data types. 
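+
+ A decimal256 value can hold up to 76 significant decimal digits.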
+ + Examples + -------- + Create an instance of decimal256 type: + + >>> import pyarrow as pa + >>> pa.decimal256(76, 38) + Decimal256Type(decimal256(76, 38)) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + FixedSizeBinaryType.init(self, type) + self.decimal256_type = type.get() + + def __reduce__(self): + return decimal256, (self.precision, self.scale) + + @property + def precision(self): + """ + The decimal precision, in number of decimal digits (an integer). + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.decimal256(76, 38) + >>> t.precision + 76 + """ + return self.decimal256_type.precision() + + @property + def scale(self): + """ + The decimal scale (an integer). + + Examples + -------- + >>> import pyarrow as pa + >>> t = pa.decimal256(76, 38) + >>> t.scale + 38 + """ + return self.decimal256_type.scale() + + +cdef class RunEndEncodedType(DataType): + """ + Concrete class for run-end encoded types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.run_end_encoded_type = type.get() + + def __reduce__(self): + return run_end_encoded, (self.run_end_type, self.value_type) + + @property + def run_end_type(self): + return pyarrow_wrap_data_type(self.run_end_encoded_type.run_end_type()) + + @property + def value_type(self): + return pyarrow_wrap_data_type(self.run_end_encoded_type.value_type()) + + +cdef class BaseExtensionType(DataType): + """ + Concrete base class for extension types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.ext_type = type.get() + + def __arrow_ext_class__(self): + """ + The associated array extension class + """ + return ExtensionArray + + def __arrow_ext_scalar_class__(self): + """ + The associated scalar class + """ + return ExtensionScalar + + @property + def extension_name(self): + """ + The extension type name. + """ + return frombytes(self.ext_type.extension_name()) + + @property + def storage_type(self): + """ + The underlying storage type. + """ + return pyarrow_wrap_data_type(self.ext_type.storage_type()) + + @property + def byte_width(self): + """ + The byte width of the extension type. + """ + if self.ext_type.byte_width() == -1: + raise ValueError("Non-fixed width type") + return self.ext_type.byte_width() + + @property + def bit_width(self): + """ + The bit width of the extension type. + """ + if self.ext_type.bit_width() == -1: + raise ValueError("Non-fixed width type") + return self.ext_type.bit_width() + + def wrap_array(self, storage): + """ + Wrap the given storage array as an extension array. 
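+
+ The type of ``storage`` (or of each chunk, for a ChunkedArray) must be
+ equal to this extension type's ``storage_type``; otherwise a
+ ``TypeError`` is raised.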
+ + Parameters + ---------- + storage : Array or ChunkedArray + + Returns + ------- + array : Array or ChunkedArray + Extension array wrapping the storage array + """ + cdef: + shared_ptr[CDataType] c_storage_type + + if isinstance(storage, Array): + c_storage_type = ( storage).ap.type() + elif isinstance(storage, ChunkedArray): + c_storage_type = ( storage).chunked_array.type() + else: + raise TypeError( + f"Expected array or chunked array, got {storage.__class__}") + + if not c_storage_type.get().Equals(deref(self.ext_type) + .storage_type(), False): + raise TypeError( + f"Incompatible storage type for {self}: " + f"expected {self.storage_type}, got {storage.type}") + + if isinstance(storage, Array): + return pyarrow_wrap_array( + self.ext_type.WrapArray( + self.sp_type, ( storage).sp_array)) + else: + return pyarrow_wrap_chunked_array( + self.ext_type.WrapArray( + self.sp_type, ( storage).sp_chunked_array)) + + +cdef class ExtensionType(BaseExtensionType): + """ + Concrete base class for Python-defined extension types. + + Parameters + ---------- + storage_type : DataType + The underlying storage type for the extension type. + extension_name : str + A unique name distinguishing this extension type. The name will be + used when deserializing IPC data. + + Examples + -------- + Define a RationalType extension type subclassing ExtensionType: + + >>> import pyarrow as pa + >>> class RationalType(pa.ExtensionType): + ... def __init__(self, data_type: pa.DataType): + ... if not pa.types.is_integer(data_type): + ... raise TypeError(f"data_type must be an integer type not {data_type}") + ... super().__init__( + ... pa.struct( + ... [ + ... ("numer", data_type), + ... ("denom", data_type), + ... ], + ... ), + ... # N.B. This name does _not_ reference `data_type` so deserialization + ... # will work for _any_ integer `data_type` after registration + ... "my_package.rational", + ... ) + ... def __arrow_ext_serialize__(self) -> bytes: + ... # No parameters are necessary + ... return b"" + ... @classmethod + ... def __arrow_ext_deserialize__(cls, storage_type, serialized): + ... # return an instance of this subclass + ... return RationalType(storage_type[0].type) + + Register the extension type: + + >>> pa.register_extension_type(RationalType(pa.int64())) + + Create an instance of RationalType extension type: + + >>> rational_type = RationalType(pa.int32()) + + Inspect the extension type: + + >>> rational_type.extension_name + 'my_package.rational' + >>> rational_type.storage_type + StructType(struct) + + Wrap an array as an extension array: + + >>> storage_array = pa.array( + ... [ + ... {"numer": 10, "denom": 17}, + ... {"numer": 20, "denom": 13}, + ... ], + ... type=rational_type.storage_type + ... 
) + >>> rational_array = rational_type.wrap_array(storage_array) + >>> rational_array + + -- is_valid: all not null + -- child 0 type: int32 + [ + 10, + 20 + ] + -- child 1 type: int32 + [ + 17, + 13 + ] + + Or do the same with creating an ExtensionArray: + + >>> rational_array = pa.ExtensionArray.from_storage(rational_type, storage_array) + >>> rational_array + + -- is_valid: all not null + -- child 0 type: int32 + [ + 10, + 20 + ] + -- child 1 type: int32 + [ + 17, + 13 + ] + + Unregister the extension type: + + >>> pa.unregister_extension_type("my_package.rational") + + Note that even though we registered the concrete type + ``RationalType(pa.int64())``, PyArrow will be able to deserialize + ``RationalType(integer_type)`` for any ``integer_type``, as the deserializer + will reference the name ``my_package.rational`` and the ``@classmethod`` + ``__arrow_ext_deserialize__``. + """ + + def __cinit__(self): + if type(self) is ExtensionType: + raise TypeError("Can only instantiate subclasses of " + "ExtensionType") + + def __init__(self, DataType storage_type, extension_name): + """ + Initialize an extension type instance. + + This should be called at the end of the subclass' + ``__init__`` method. + """ + cdef: + shared_ptr[CExtensionType] cpy_ext_type + c_string c_extension_name + + c_extension_name = tobytes(extension_name) + + assert storage_type is not None + check_status(CPyExtensionType.FromClass( + storage_type.sp_type, c_extension_name, type(self), + &cpy_ext_type)) + self.init( cpy_ext_type) + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.cpy_ext_type = type.get() + # Store weakref and serialized version of self on C++ type instance + check_status(self.cpy_ext_type.SetInstance(self)) + + def __eq__(self, other): + # Default implementation to avoid infinite recursion through + # DataType.__eq__ -> ExtensionType::ExtensionEquals -> DataType.__eq__ + if isinstance(other, ExtensionType): + return (type(self) == type(other) and + self.extension_name == other.extension_name and + self.storage_type == other.storage_type) + else: + return NotImplemented + + def __repr__(self): + fmt = '{0.__class__.__name__}({1})' + return fmt.format(self, repr(self.storage_type)) + + def __arrow_ext_serialize__(self): + """ + Serialized representation of metadata to reconstruct the type object. + + This method should return a bytes object, and those serialized bytes + are stored in the custom metadata of the Field holding an extension + type in an IPC message. + The bytes are passed to ``__arrow_ext_deserialize`` and should hold + sufficient information to reconstruct the data type instance. + """ + return NotImplementedError + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + """ + Return an extension type instance from the storage type and serialized + metadata. + + This method should return an instance of the ExtensionType subclass + that matches the passed storage type and serialized metadata (the + return value of ``__arrow_ext_serialize__``). + """ + return NotImplementedError + + def __reduce__(self): + return self.__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__()) + + def __arrow_ext_class__(self): + """Return an extension array class to be used for building or + deserializing arrays with this extension type. + + This method should return a subclass of the ExtensionArray class. 
By + default, if not specialized in the extension implementation, an + extension type array will be a built-in ExtensionArray instance. + """ + return ExtensionArray + + def __arrow_ext_scalar_class__(self): + """Return an extension scalar class for building scalars with this + extension type. + + This method should return subclass of the ExtensionScalar class. By + default, if not specialized in the extension implementation, an + extension type scalar will be a built-in ExtensionScalar instance. + """ + return ExtensionScalar + + +cdef class JsonType(BaseExtensionType): + """ + Concrete class for JSON extension type. + + Examples + -------- + Define the extension type for JSON array + + >>> import pyarrow as pa + >>> json_type = pa.json_(pa.large_utf8()) + + Create an extension array + + >>> arr = [None, '{ "id":30, "values":["a", "b"] }'] + >>> storage = pa.array(arr, pa.large_utf8()) + >>> pa.ExtensionArray.from_storage(json_type, storage) + + [ + null, + "{ "id":30, "values":["a", "b"] }" + ] + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.json_ext_type = type.get() + + def __arrow_ext_class__(self): + return JsonArray + + def __reduce__(self): + return json_, (self.storage_type,) + + def __arrow_ext_scalar_class__(self): + return JsonScalar + + +cdef class UuidType(BaseExtensionType): + """ + Concrete class for UUID extension type. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.uuid_ext_type = type.get() + + def __arrow_ext_class__(self): + return UuidArray + + def __reduce__(self): + return uuid, () + + def __arrow_ext_scalar_class__(self): + return UuidScalar + + +cdef class FixedShapeTensorType(BaseExtensionType): + """ + Concrete class for fixed shape tensor extension type. + + Examples + -------- + Create an instance of fixed shape tensor extension type: + + >>> import pyarrow as pa + >>> pa.fixed_shape_tensor(pa.int32(), [2, 2]) + FixedShapeTensorType(extension) + + Create an instance of fixed shape tensor extension type with + permutation: + + >>> tensor_type = pa.fixed_shape_tensor(pa.int8(), (2, 2, 3), + ... permutation=[0, 2, 1]) + >>> tensor_type.permutation + [0, 2, 1] + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.tensor_ext_type = type.get() + + @property + def value_type(self): + """ + Data type of an individual tensor. + """ + return pyarrow_wrap_data_type(self.tensor_ext_type.value_type()) + + @property + def shape(self): + """ + Shape of the tensors. + """ + return self.tensor_ext_type.shape() + + @property + def dim_names(self): + """ + Explicit names of the dimensions. + """ + list_of_bytes = self.tensor_ext_type.dim_names() + if len(list_of_bytes) != 0: + return [frombytes(x) for x in list_of_bytes] + else: + return None + + @property + def permutation(self): + """ + Indices of the dimensions ordering. + """ + indices = self.tensor_ext_type.permutation() + if len(indices) != 0: + return indices + else: + return None + + def __arrow_ext_class__(self): + return FixedShapeTensorArray + + def __reduce__(self): + return fixed_shape_tensor, (self.value_type, self.shape, + self.dim_names, self.permutation) + + def __arrow_ext_scalar_class__(self): + return FixedShapeTensorScalar + + +cdef class Bool8Type(BaseExtensionType): + """ + Concrete class for bool8 extension type. 
+ + Bool8 is an alternate representation for boolean + arrays using 8 bits instead of 1 bit per value. The underlying + storage type is int8. + + Examples + -------- + Create an instance of bool8 extension type: + + >>> import pyarrow as pa + >>> pa.bool8() + Bool8Type(extension) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.bool8_ext_type = type.get() + + def __arrow_ext_class__(self): + return Bool8Array + + def __reduce__(self): + return bool8, () + + def __arrow_ext_scalar_class__(self): + return Bool8Scalar + + +cdef class OpaqueType(BaseExtensionType): + """ + Concrete class for opaque extension type. + + Opaque is a placeholder for a type from an external (often non-Arrow) + system that could not be interpreted. + + Examples + -------- + Create an instance of opaque extension type: + + >>> import pyarrow as pa + >>> pa.opaque(pa.int32(), "geometry", "postgis") + OpaqueType(extension) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.opaque_ext_type = type.get() + + @property + def type_name(self): + """ + The name of the type in the external system. + """ + return frombytes(c_string(self.opaque_ext_type.type_name())) + + @property + def vendor_name(self): + """ + The name of the external system. + """ + return frombytes(c_string(self.opaque_ext_type.vendor_name())) + + def __arrow_ext_class__(self): + return OpaqueArray + + def __reduce__(self): + return opaque, (self.storage_type, self.type_name, self.vendor_name) + + def __arrow_ext_scalar_class__(self): + return OpaqueScalar + + +_py_extension_type_auto_load = False + + +cdef class PyExtensionType(ExtensionType): + """ + Concrete base class for Python-defined extension types based on pickle + for (de)serialization. + + .. warning:: + This class is deprecated and its deserialization is disabled by default. + :class:`ExtensionType` is recommended instead. + + Parameters + ---------- + storage_type : DataType + The storage type for which the extension is built. + """ + + def __cinit__(self): + if type(self) is PyExtensionType: + raise TypeError("Can only instantiate subclasses of " + "PyExtensionType") + + def __init__(self, DataType storage_type): + warnings.warn( + "pyarrow.PyExtensionType is deprecated " + "and will refuse deserialization by default. " + "Instead, please derive from pyarrow.ExtensionType and implement " + "your own serialization mechanism.", + FutureWarning) + ExtensionType.__init__(self, storage_type, "arrow.py_extension_type") + + def __reduce__(self): + raise NotImplementedError("Please implement {0}.__reduce__" + .format(type(self).__name__)) + + def __arrow_ext_serialize__(self): + return pickle.dumps(self) + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + if not _py_extension_type_auto_load: + warnings.warn( + "pickle-based deserialization of pyarrow.PyExtensionType subclasses " + "is disabled by default; if you only ingest " + "trusted data files, you may re-enable this using " + "`pyarrow.PyExtensionType.set_auto_load(True)`.\n" + "In the future, Python-defined extension subclasses should " + "derive from pyarrow.ExtensionType (not pyarrow.PyExtensionType) " + "and implement their own serialization mechanism.\n", + RuntimeWarning) + return UnknownExtensionType(storage_type, serialized) + try: + ty = pickle.loads(serialized) + except Exception: + # For some reason, it's impossible to deserialize the + # ExtensionType instance. 
Perhaps the serialized data is + # corrupt, or more likely the type is being deserialized + # in an environment where the original Python class or module + # is not available. Fall back on a generic BaseExtensionType. + return UnknownExtensionType(storage_type, serialized) + + if ty.storage_type != storage_type: + raise TypeError("Expected storage type {0} but got {1}" + .format(ty.storage_type, storage_type)) + return ty + + # XXX Cython marks extension types as immutable, so cannot expose this + # as a writable class attribute. + @classmethod + def set_auto_load(cls, value): + """ + Enable or disable auto-loading of serialized PyExtensionType instances. + + Parameters + ---------- + value : bool + Whether to enable auto-loading. + """ + global _py_extension_type_auto_load + assert isinstance(value, bool) + _py_extension_type_auto_load = value + + +cdef class UnknownExtensionType(PyExtensionType): + """ + A concrete class for Python-defined extension types that refer to + an unknown Python implementation. + + Parameters + ---------- + storage_type : DataType + The storage type for which the extension is built. + serialized : bytes + The serialised output. + """ + + cdef: + bytes serialized + + def __init__(self, DataType storage_type, serialized): + self.serialized = serialized + PyExtensionType.__init__(self, storage_type) + + def __arrow_ext_serialize__(self): + return self.serialized + + +_python_extension_types_registry = [] + + +def register_extension_type(ext_type): + """ + Register a Python extension type. + + Registration is based on the extension name (so different registered types + need unique extension names). Registration needs an extension type + instance, but then works for any instance of the same subclass regardless + of parametrization of the type. + + Parameters + ---------- + ext_type : BaseExtensionType instance + The ExtensionType subclass to register. + + Examples + -------- + Define a RationalType extension type subclassing ExtensionType: + + >>> import pyarrow as pa + >>> class RationalType(pa.ExtensionType): + ... def __init__(self, data_type: pa.DataType): + ... if not pa.types.is_integer(data_type): + ... raise TypeError(f"data_type must be an integer type not {data_type}") + ... super().__init__( + ... pa.struct( + ... [ + ... ("numer", data_type), + ... ("denom", data_type), + ... ], + ... ), + ... # N.B. This name does _not_ reference `data_type` so deserialization + ... # will work for _any_ integer `data_type` after registration + ... "my_package.rational", + ... ) + ... def __arrow_ext_serialize__(self) -> bytes: + ... # No parameters are necessary + ... return b"" + ... @classmethod + ... def __arrow_ext_deserialize__(cls, storage_type, serialized): + ... # return an instance of this subclass + ... return RationalType(storage_type[0].type) + + Register the extension type: + + >>> pa.register_extension_type(RationalType(pa.int64())) + + Unregister the extension type: + + >>> pa.unregister_extension_type("my_package.rational") + """ + cdef: + DataType _type = ensure_type(ext_type, allow_none=False) + + if not isinstance(_type, BaseExtensionType): + raise TypeError("Only extension types can be registered") + + # register on the C++ side + check_status( + RegisterPyExtensionType( _type.sp_type)) + + # register on the python side + _python_extension_types_registry.append(_type) + + +def unregister_extension_type(type_name): + """ + Unregister a Python extension type. 
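+
+ The lookup is by extension name, so the type must previously have been
+ registered under that name with ``register_extension_type``.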
+ + Parameters + ---------- + type_name : str + The name of the ExtensionType subclass to unregister. + + Examples + -------- + Define a RationalType extension type subclassing ExtensionType: + + >>> import pyarrow as pa + >>> class RationalType(pa.ExtensionType): + ... def __init__(self, data_type: pa.DataType): + ... if not pa.types.is_integer(data_type): + ... raise TypeError(f"data_type must be an integer type not {data_type}") + ... super().__init__( + ... pa.struct( + ... [ + ... ("numer", data_type), + ... ("denom", data_type), + ... ], + ... ), + ... # N.B. This name does _not_ reference `data_type` so deserialization + ... # will work for _any_ integer `data_type` after registration + ... "my_package.rational", + ... ) + ... def __arrow_ext_serialize__(self) -> bytes: + ... # No parameters are necessary + ... return b"" + ... @classmethod + ... def __arrow_ext_deserialize__(cls, storage_type, serialized): + ... # return an instance of this subclass + ... return RationalType(storage_type[0].type) + + Register the extension type: + + >>> pa.register_extension_type(RationalType(pa.int64())) + + Unregister the extension type: + + >>> pa.unregister_extension_type("my_package.rational") + """ + cdef: + c_string c_type_name = tobytes(type_name) + check_status(UnregisterPyExtensionType(c_type_name)) + + +cdef class KeyValueMetadata(_Metadata, Mapping): + """ + KeyValueMetadata + + Parameters + ---------- + __arg0__ : dict + A dict of the key-value metadata + **kwargs : optional + additional key-value metadata + """ + + def __init__(self, __arg0__=None, **kwargs): + cdef: + vector[c_string] keys, values + shared_ptr[const CKeyValueMetadata] result + + items = [] + if __arg0__ is not None: + other = (__arg0__.items() if isinstance(__arg0__, Mapping) + else __arg0__) + items.extend((tobytes(k), v) for k, v in other) + + prior_keys = {k for k, v in items} + for k, v in kwargs.items(): + k = tobytes(k) + if k in prior_keys: + raise KeyError("Duplicate key {}, " + "use pass all items as list of tuples if you " + "intend to have duplicate keys") + items.append((k, v)) + + keys.reserve(len(items)) + for key, value in items: + keys.push_back(tobytes(key)) + values.push_back(tobytes(value)) + result.reset(new CKeyValueMetadata(move(keys), move(values))) + self.init(result) + + cdef void init(self, const shared_ptr[const CKeyValueMetadata]& wrapped): + self.wrapped = wrapped + self.metadata = wrapped.get() + + @staticmethod + cdef wrap(const shared_ptr[const CKeyValueMetadata]& sp): + cdef KeyValueMetadata self = KeyValueMetadata.__new__(KeyValueMetadata) + self.init(sp) + return self + + cdef inline shared_ptr[const CKeyValueMetadata] unwrap(self) nogil: + return self.wrapped + + def equals(self, KeyValueMetadata other): + """ + Parameters + ---------- + other : pyarrow.KeyValueMetadata + + Returns + ------- + bool + """ + return self.metadata.Equals(deref(other.wrapped)) + + def __repr__(self): + return str(self) + + def __str__(self): + return frombytes(self.metadata.ToString(), safe=True) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + pass + + if isinstance(other, Mapping): + try: + other = KeyValueMetadata(other) + return self.equals(other) + except TypeError: + pass + + return NotImplemented + + def __len__(self): + return self.metadata.size() + + def __contains__(self, key): + return self.metadata.Contains(tobytes(key)) + + def __getitem__(self, key): + return GetResultValue(self.metadata.Get(tobytes(key))) + + def __iter__(self): + return self.keys() + 
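+ # A minimal usage sketch for the constructor and accessors above,
+ # assuming ``import pyarrow as pa``:
+ #
+ #     meta = pa.KeyValueMetadata({'a': '1'}, b='2')        # dict plus kwargs
+ #     dup = pa.KeyValueMetadata([('a', '1'), ('a', '2')])  # duplicate keys via tuples
+ #     dup.get_all('a')   # -> [b'1', b'2']
+ #     dup.to_dict()      # -> {b'a': b'1'}  (first occurrence wins)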
+ def __reduce__(self): + return KeyValueMetadata, (list(self.items()),) + + def key(self, i): + """ + Parameters + ---------- + i : int + + Returns + ------- + byte + """ + return self.metadata.key(i) + + def value(self, i): + """ + Parameters + ---------- + i : int + + Returns + ------- + byte + """ + return self.metadata.value(i) + + def keys(self): + for i in range(self.metadata.size()): + yield self.metadata.key(i) + + def values(self): + for i in range(self.metadata.size()): + yield self.metadata.value(i) + + def items(self): + for i in range(self.metadata.size()): + yield (self.metadata.key(i), self.metadata.value(i)) + + def get_all(self, key): + """ + Parameters + ---------- + key : str + + Returns + ------- + list[byte] + """ + key = tobytes(key) + return [v for k, v in self.items() if k == key] + + def to_dict(self): + """ + Convert KeyValueMetadata to dict. If a key occurs twice, the value for + the first one is returned + """ + cdef object key # to force coercion to Python + result = ordered_dict() + for i in range(self.metadata.size()): + key = self.metadata.key(i) + if key not in result: + result[key] = self.metadata.value(i) + return result + + +cpdef KeyValueMetadata ensure_metadata(object meta, c_bool allow_none=False): + if allow_none and meta is None: + return None + elif isinstance(meta, KeyValueMetadata): + return meta + else: + return KeyValueMetadata(meta) + + +cdef class Field(_Weakrefable): + """ + A named field, with a data type, nullability, and optional metadata. + + Notes + ----- + Do not use this class's constructor directly; use pyarrow.field + + Examples + -------- + Create an instance of pyarrow.Field: + + >>> import pyarrow as pa + >>> pa.field('key', pa.int32()) + pyarrow.Field + >>> pa.field('key', pa.int32(), nullable=False) + pyarrow.Field + >>> field = pa.field('key', pa.int32(), + ... metadata={"key": "Something important"}) + >>> field + pyarrow.Field + >>> field.metadata + {b'key': b'Something important'} + + Use the field to create a struct type: + + >>> pa.struct([field]) + StructType(struct) + """ + + def __cinit__(self): + pass + + def __init__(self): + raise TypeError("Do not call Field's constructor directly, use " + "`pyarrow.field` instead.") + + cdef void init(self, const shared_ptr[CField]& field): + self.sp_field = field + self.field = field.get() + self.type = pyarrow_wrap_data_type(field.get().type()) + + def equals(self, Field other, bint check_metadata=False): + """ + Test if this field is equal to the other + + Parameters + ---------- + other : pyarrow.Field + check_metadata : bool, default False + Whether Field metadata equality should be checked as well. + + Returns + ------- + is_equal : bool + + Examples + -------- + >>> import pyarrow as pa + >>> f1 = pa.field('key', pa.int32()) + >>> f2 = pa.field('key', pa.int32(), nullable=False) + >>> f1.equals(f2) + False + >>> f1.equals(f1) + True + """ + return self.field.Equals(deref(other.field), check_metadata) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def __reduce__(self): + return field, (self.name, self.type, self.nullable, self.metadata) + + def __str__(self): + return 'pyarrow.Field<{0}>'.format( + frombytes(self.field.ToString(), safe=True)) + + def __repr__(self): + return self.__str__() + + def __hash__(self): + return hash((self.field.name(), self.type, self.field.nullable())) + + @property + def nullable(self): + """ + The field nullability. 
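+
+ ``True`` (the default for newly created fields) if the field may
+ contain null values, ``False`` otherwise.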
+ + Examples + -------- + >>> import pyarrow as pa + >>> f1 = pa.field('key', pa.int32()) + >>> f2 = pa.field('key', pa.int32(), nullable=False) + >>> f1.nullable + True + >>> f2.nullable + False + """ + return self.field.nullable() + + @property + def name(self): + """ + The field name. + + Examples + -------- + >>> import pyarrow as pa + >>> field = pa.field('key', pa.int32()) + >>> field.name + 'key' + """ + return frombytes(self.field.name()) + + @property + def metadata(self): + """ + The field metadata (if any is set). + + Returns + ------- + metadata : dict or None + + Examples + -------- + >>> import pyarrow as pa + >>> field = pa.field('key', pa.int32(), + ... metadata={"key": "Something important"}) + >>> field.metadata + {b'key': b'Something important'} + """ + wrapped = pyarrow_wrap_metadata(self.field.metadata()) + if wrapped is not None: + return wrapped.to_dict() + else: + return wrapped + + def with_metadata(self, metadata): + """ + Add metadata as dict of string keys and values to Field + + Parameters + ---------- + metadata : dict + Keys and values must be string-like / coercible to bytes + + Returns + ------- + field : pyarrow.Field + + Examples + -------- + >>> import pyarrow as pa + >>> field = pa.field('key', pa.int32()) + + Create new field by adding metadata to existing one: + + >>> field_new = field.with_metadata({"key": "Something important"}) + >>> field_new + pyarrow.Field + >>> field_new.metadata + {b'key': b'Something important'} + """ + cdef shared_ptr[CField] c_field + + meta = ensure_metadata(metadata, allow_none=False) + with nogil: + c_field = self.field.WithMetadata(meta.unwrap()) + + return pyarrow_wrap_field(c_field) + + def remove_metadata(self): + """ + Create new field without metadata, if any + + Returns + ------- + field : pyarrow.Field + + Examples + -------- + >>> import pyarrow as pa + >>> field = pa.field('key', pa.int32(), + ... 
metadata={"key": "Something important"}) + >>> field.metadata + {b'key': b'Something important'} + + Create new field by removing the metadata from the existing one: + + >>> field_new = field.remove_metadata() + >>> field_new.metadata + """ + cdef shared_ptr[CField] new_field + with nogil: + new_field = self.field.RemoveMetadata() + return pyarrow_wrap_field(new_field) + + def with_type(self, DataType new_type): + """ + A copy of this field with the replaced type + + Parameters + ---------- + new_type : pyarrow.DataType + + Returns + ------- + field : pyarrow.Field + + Examples + -------- + >>> import pyarrow as pa + >>> field = pa.field('key', pa.int32()) + >>> field + pyarrow.Field + + Create new field by replacing type of an existing one: + + >>> field_new = field.with_type(pa.int64()) + >>> field_new + pyarrow.Field + """ + cdef: + shared_ptr[CField] c_field + shared_ptr[CDataType] c_datatype + + c_datatype = pyarrow_unwrap_data_type(new_type) + with nogil: + c_field = self.field.WithType(c_datatype) + + return pyarrow_wrap_field(c_field) + + def with_name(self, name): + """ + A copy of this field with the replaced name + + Parameters + ---------- + name : str + + Returns + ------- + field : pyarrow.Field + + Examples + -------- + >>> import pyarrow as pa + >>> field = pa.field('key', pa.int32()) + >>> field + pyarrow.Field + + Create new field by replacing the name of an existing one: + + >>> field_new = field.with_name('lock') + >>> field_new + pyarrow.Field + """ + cdef: + shared_ptr[CField] c_field + + c_field = self.field.WithName(tobytes(name)) + + return pyarrow_wrap_field(c_field) + + def with_nullable(self, nullable): + """ + A copy of this field with the replaced nullability + + Parameters + ---------- + nullable : bool + + Returns + ------- + field: pyarrow.Field + + Examples + -------- + >>> import pyarrow as pa + >>> field = pa.field('key', pa.int32()) + >>> field + pyarrow.Field + >>> field.nullable + True + + Create new field by replacing the nullability of an existing one: + + >>> field_new = field.with_nullable(False) + >>> field_new + pyarrow.Field + >>> field_new.nullable + False + """ + cdef: + shared_ptr[CField] field + c_bool c_nullable + + c_nullable = bool(nullable) + with nogil: + c_field = self.field.WithNullable(c_nullable) + + return pyarrow_wrap_field(c_field) + + def flatten(self): + """ + Flatten this field. If a struct field, individual child fields + will be returned with their names prefixed by the parent's name. + + Returns + ------- + fields : List[pyarrow.Field] + + Examples + -------- + >>> import pyarrow as pa + >>> f1 = pa.field('bar', pa.float64(), nullable=False) + >>> f2 = pa.field('foo', pa.int32()).with_metadata({"key": "Something important"}) + >>> ff = pa.field('ff', pa.struct([f1, f2]), nullable=False) + + Flatten a struct field: + + >>> ff + pyarrow.Field not null> + >>> ff.flatten() + [pyarrow.Field, pyarrow.Field] + """ + cdef vector[shared_ptr[CField]] flattened + with nogil: + flattened = self.field.Flatten() + return [pyarrow_wrap_field(f) for f in flattened] + + def _export_to_c(self, out_ptr): + """ + Export to a C ArrowSchema struct, given its pointer. + + Be careful: if you don't pass the ArrowSchema struct to a consumer, + its memory will leak. This is a low-level function intended for + expert users. + """ + check_status(ExportField(deref(self.field), + _as_c_pointer(out_ptr))) + + @staticmethod + def _import_from_c(in_ptr): + """ + Import Field from a C ArrowSchema struct, given its pointer. 
+ + This is a low-level function intended for expert users. + """ + cdef void* c_ptr = _as_c_pointer(in_ptr) + with nogil: + result = GetResultValue(ImportField( c_ptr)) + return pyarrow_wrap_field(result) + + def __arrow_c_schema__(self): + """ + Export to a ArrowSchema PyCapsule + + Unlike _export_to_c, this will not leak memory if the capsule is not used. + """ + cdef ArrowSchema* c_schema + capsule = alloc_c_schema(&c_schema) + + with nogil: + check_status(ExportField(deref(self.field), c_schema)) + + return capsule + + @staticmethod + def _import_from_c_capsule(schema): + """ + Import a Field from a ArrowSchema PyCapsule + + Parameters + ---------- + schema : PyCapsule + A valid PyCapsule with name 'arrow_schema' containing an + ArrowSchema pointer. + """ + cdef: + ArrowSchema* c_schema + shared_ptr[CField] c_field + + if not PyCapsule_IsValid(schema, 'arrow_schema'): + raise ValueError( + "Not an ArrowSchema object" + ) + c_schema = PyCapsule_GetPointer(schema, 'arrow_schema') + + with nogil: + c_field = GetResultValue(ImportField(c_schema)) + + return pyarrow_wrap_field(c_field) + + +cdef class Schema(_Weakrefable): + """ + A named collection of types a.k.a schema. A schema defines the + column names and types in a record batch or table data structure. + They also contain metadata about the columns. For example, schemas + converted from Pandas contain metadata about their original Pandas + types so they can be converted back to the same types. + + Warnings + -------- + Do not call this class's constructor directly. Instead use + :func:`pyarrow.schema` factory function which makes a new Arrow + Schema object. + + Examples + -------- + Create a new Arrow Schema object: + + >>> import pyarrow as pa + >>> pa.schema([ + ... ('some_int', pa.int32()), + ... ('some_string', pa.string()) + ... ]) + some_int: int32 + some_string: string + + Create Arrow Schema with metadata: + + >>> pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())], + ... metadata={"n_legs": "Number of legs per animal"}) + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'Number of legs per animal' + """ + + def __cinit__(self): + pass + + def __init__(self): + raise TypeError("Do not call Schema's constructor directly, use " + "`pyarrow.schema` instead.") + + def __len__(self): + return self.schema.num_fields() + + def __getitem__(self, key): + # access by integer index + return self._field(key) + + def __iter__(self): + for i in range(len(self)): + yield self[i] + + cdef void init(self, const vector[shared_ptr[CField]]& fields): + self.schema = new CSchema(fields) + self.sp_schema.reset(self.schema) + + cdef void init_schema(self, const shared_ptr[CSchema]& schema): + self.schema = schema.get() + self.sp_schema = schema + + def __reduce__(self): + return schema, (list(self), self.metadata) + + def __hash__(self): + return hash((tuple(self), self.metadata)) + + def __sizeof__(self): + size = 0 + if self.metadata: + for key, value in self.metadata.items(): + size += sys.getsizeof(key) + size += sys.getsizeof(value) + + return size + super(Schema, self).__sizeof__() + + @property + def pandas_metadata(self): + """ + Return deserialized-from-JSON pandas metadata field (if it exists) + + Examples + -------- + >>> import pyarrow as pa + >>> import pandas as pd + >>> df = pd.DataFrame({'n_legs': [2, 4, 5, 100], + ... 
'animals': ["Flamingo", "Horse", "Brittle stars", "Centipede"]}) + >>> schema = pa.Table.from_pandas(df).schema + + Select pandas metadata field from Arrow Schema: + + >>> schema.pandas_metadata + {'index_columns': [{'kind': 'range', 'name': None, 'start': 0, 'stop': 4, 'step': 1}], ... + """ + metadata = self.metadata + key = b'pandas' + if metadata is None or key not in metadata: + return None + + import json + return json.loads(metadata[key].decode('utf8')) + + @property + def names(self): + """ + The schema's field names. + + Returns + ------- + list of str + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())]) + + Get the names of the schema's fields: + + >>> schema.names + ['n_legs', 'animals'] + """ + cdef int i + result = [] + for i in range(self.schema.num_fields()): + name = frombytes(self.schema.field(i).get().name()) + result.append(name) + return result + + @property + def types(self): + """ + The schema's field types. + + Returns + ------- + list of DataType + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())]) + + Get the types of the schema's fields: + + >>> schema.types + [DataType(int64), DataType(string)] + """ + return [field.type for field in self] + + @property + def metadata(self): + """ + The schema's metadata (if any is set). + + Returns + ------- + metadata: dict or None + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())], + ... metadata={"n_legs": "Number of legs per animal"}) + + Get the metadata of the schema's fields: + + >>> schema.metadata + {b'n_legs': b'Number of legs per animal'} + """ + wrapped = pyarrow_wrap_metadata(self.schema.metadata()) + if wrapped is not None: + return wrapped.to_dict() + else: + return wrapped + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def empty_table(self): + """ + Provide an empty table according to the schema. + + Returns + ------- + table: pyarrow.Table + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())]) + + Create an empty table with schema's fields: + + >>> schema.empty_table() + pyarrow.Table + n_legs: int64 + animals: string + ---- + n_legs: [[]] + animals: [[]] + """ + arrays = [_empty_array(field.type) for field in self] + return Table.from_arrays(arrays, schema=self) + + def equals(self, Schema other not None, bint check_metadata=False): + """ + Test if this schema is equal to the other + + Parameters + ---------- + other : pyarrow.Schema + check_metadata : bool, default False + Key/value metadata must be equal too + + Returns + ------- + is_equal : bool + + Examples + -------- + >>> import pyarrow as pa + >>> schema1 = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())], + ... metadata={"n_legs": "Number of legs per animal"}) + >>> schema2 = pa.schema([ + ... ('some_int', pa.int32()), + ... ('some_string', pa.string()) + ... 
]) + + Test two equal schemas: + + >>> schema1.equals(schema1) + True + + Test two unequal schemas: + + >>> schema1.equals(schema2) + False + """ + return self.sp_schema.get().Equals(deref(other.schema), + check_metadata) + + @classmethod + def from_pandas(cls, df, preserve_index=None): + """ + Returns implied schema from dataframe + + Parameters + ---------- + df : pandas.DataFrame + preserve_index : bool, default True + Whether to store the index as an additional column (or columns, for + MultiIndex) in the resulting `Table`. + The default of None will store the index as a column, except for + RangeIndex which is stored as metadata only. Use + ``preserve_index=True`` to force it to be stored as a column. + + Returns + ------- + pyarrow.Schema + + Examples + -------- + >>> import pandas as pd + >>> import pyarrow as pa + >>> df = pd.DataFrame({ + ... 'int': [1, 2], + ... 'str': ['a', 'b'] + ... }) + + Create an Arrow Schema from the schema of a pandas dataframe: + + >>> pa.Schema.from_pandas(df) + int: int64 + str: string + -- schema metadata -- + pandas: '{"index_columns": [{"kind": "range", "name": null, ... + """ + from pyarrow.pandas_compat import dataframe_to_types + names, types, metadata = dataframe_to_types( + df, + preserve_index=preserve_index + ) + fields = [] + for name, type_ in zip(names, types): + fields.append(field(name, type_)) + return schema(fields, metadata) + + def field(self, i): + """ + Select a field by its column name or numeric index. + + Parameters + ---------- + i : int or string + + Returns + ------- + pyarrow.Field + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())]) + + Select the second field: + + >>> schema.field(1) + pyarrow.Field + + Select the field of the column named 'n_legs': + + >>> schema.field('n_legs') + pyarrow.Field + """ + if isinstance(i, (bytes, str)): + field_index = self.get_field_index(i) + if field_index < 0: + raise KeyError("Column {} does not exist in schema".format(i)) + else: + return self._field(field_index) + elif isinstance(i, int): + return self._field(i) + else: + raise TypeError("Index must either be string or integer") + + def _field(self, int i): + """ + Select a field by its numeric index. + + Parameters + ---------- + i : int + + Returns + ------- + pyarrow.Field + """ + cdef int index = _normalize_index(i, self.schema.num_fields()) + return pyarrow_wrap_field(self.schema.field(index)) + + def field_by_name(self, name): + """ + DEPRECATED + + Parameters + ---------- + name : str + + Returns + ------- + field: pyarrow.Field + """ + cdef: + vector[shared_ptr[CField]] results + + warnings.warn( + "The 'field_by_name' method is deprecated, use 'field' instead", + FutureWarning, stacklevel=2) + + results = self.schema.GetAllFieldsByName(tobytes(name)) + if results.size() == 0: + return None + elif results.size() > 1: + warnings.warn("Schema field name corresponds to more " + "than one field", UserWarning) + return None + else: + return pyarrow_wrap_field(results[0]) + + def get_field_index(self, name): + """ + Return index of the unique field with the given name. + + Parameters + ---------- + name : str + The name of the field to look up. + + Returns + ------- + index : int + The index of the field with the given name; -1 if the + name isn't found or there are several fields with the given + name. + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... 
pa.field('animals', pa.string())]) + + Get the index of the field named 'animals': + + >>> schema.get_field_index("animals") + 1 + + Index in case of several fields with the given name: + + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string()), + ... pa.field('animals', pa.bool_())], + ... metadata={"n_legs": "Number of legs per animal"}) + >>> schema.get_field_index("animals") + -1 + """ + return self.schema.GetFieldIndex(tobytes(name)) + + def get_all_field_indices(self, name): + """ + Return sorted list of indices for the fields with the given name. + + Parameters + ---------- + name : str + The name of the field to look up. + + Returns + ------- + indices : List[int] + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string()), + ... pa.field('animals', pa.bool_())]) + + Get the indexes of the fields named 'animals': + + >>> schema.get_all_field_indices("animals") + [1, 2] + """ + return self.schema.GetAllFieldIndices(tobytes(name)) + + def append(self, Field field): + """ + Append a field at the end of the schema. + + In contrast to Python's ``list.append()`` it does return a new + object, leaving the original Schema unmodified. + + Parameters + ---------- + field : Field + + Returns + ------- + schema: Schema + New object with appended field. + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())]) + + Append a field 'extra' at the end of the schema: + + >>> schema_new = schema.append(pa.field('extra', pa.bool_())) + >>> schema_new + n_legs: int64 + animals: string + extra: bool + + Original schema is unmodified: + + >>> schema + n_legs: int64 + animals: string + """ + return self.insert(self.schema.num_fields(), field) + + def insert(self, int i, Field field): + """ + Add a field at position i to the schema. + + Parameters + ---------- + i : int + field : Field + + Returns + ------- + schema: Schema + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())]) + + Insert a new field on the second position: + + >>> schema.insert(1, pa.field('extra', pa.bool_())) + n_legs: int64 + extra: bool + animals: string + """ + cdef: + shared_ptr[CSchema] new_schema + shared_ptr[CField] c_field + + c_field = field.sp_field + + with nogil: + new_schema = GetResultValue(self.schema.AddField(i, c_field)) + + return pyarrow_wrap_schema(new_schema) + + def remove(self, int i): + """ + Remove the field at index i from the schema. + + Parameters + ---------- + i : int + + Returns + ------- + schema: Schema + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())]) + + Remove the second field of the schema: + + >>> schema.remove(1) + n_legs: int64 + """ + cdef shared_ptr[CSchema] new_schema + + with nogil: + new_schema = GetResultValue(self.schema.RemoveField(i)) + + return pyarrow_wrap_schema(new_schema) + + def set(self, int i, Field field): + """ + Replace a field at position i in the schema. + + Parameters + ---------- + i : int + field : Field + + Returns + ------- + schema: Schema + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... 
pa.field('animals', pa.string())]) + + Replace the second field of the schema with a new field 'extra': + + >>> schema.set(1, pa.field('replaced', pa.bool_())) + n_legs: int64 + replaced: bool + """ + cdef: + shared_ptr[CSchema] new_schema + shared_ptr[CField] c_field + + c_field = field.sp_field + + with nogil: + new_schema = GetResultValue(self.schema.SetField(i, c_field)) + + return pyarrow_wrap_schema(new_schema) + + def add_metadata(self, metadata): + """ + DEPRECATED + + Parameters + ---------- + metadata : dict + Keys and values must be string-like / coercible to bytes + """ + warnings.warn("The 'add_metadata' method is deprecated, use " + "'with_metadata' instead", FutureWarning, stacklevel=2) + return self.with_metadata(metadata) + + def with_metadata(self, metadata): + """ + Add metadata as dict of string keys and values to Schema + + Parameters + ---------- + metadata : dict + Keys and values must be string-like / coercible to bytes + + Returns + ------- + schema : pyarrow.Schema + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())]) + + Add metadata to existing schema field: + + >>> schema.with_metadata({"n_legs": "Number of legs per animal"}) + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'Number of legs per animal' + """ + cdef shared_ptr[CSchema] c_schema + + meta = ensure_metadata(metadata, allow_none=False) + with nogil: + c_schema = self.schema.WithMetadata(meta.unwrap()) + + return pyarrow_wrap_schema(c_schema) + + def serialize(self, memory_pool=None): + """ + Write Schema to Buffer as encapsulated IPC message + + Parameters + ---------- + memory_pool : MemoryPool, default None + Uses default memory pool if not specified + + Returns + ------- + serialized : Buffer + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())]) + + Write schema to Buffer: + + >>> schema.serialize() + + """ + cdef: + shared_ptr[CBuffer] buffer + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + + with nogil: + buffer = GetResultValue(SerializeSchema(deref(self.schema), + pool)) + return pyarrow_wrap_buffer(buffer) + + def remove_metadata(self): + """ + Create new schema without metadata, if any + + Returns + ------- + schema : pyarrow.Schema + + Examples + -------- + >>> import pyarrow as pa + >>> schema = pa.schema([ + ... pa.field('n_legs', pa.int64()), + ... pa.field('animals', pa.string())], + ... 
metadata={"n_legs": "Number of legs per animal"}) + >>> schema + n_legs: int64 + animals: string + -- schema metadata -- + n_legs: 'Number of legs per animal' + + Create a new schema with removing the metadata from the original: + + >>> schema.remove_metadata() + n_legs: int64 + animals: string + """ + cdef shared_ptr[CSchema] new_schema + with nogil: + new_schema = self.schema.RemoveMetadata() + return pyarrow_wrap_schema(new_schema) + + def to_string(self, truncate_metadata=True, show_field_metadata=True, + show_schema_metadata=True): + """ + Return human-readable representation of Schema + + Parameters + ---------- + truncate_metadata : boolean, default True + Limit metadata key/value display to a single line of ~80 characters + or less + show_field_metadata : boolean, default True + Display Field-level KeyValueMetadata + show_schema_metadata : boolean, default True + Display Schema-level KeyValueMetadata + + Returns + ------- + str : the formatted output + """ + cdef: + c_string result + PrettyPrintOptions options = PrettyPrintOptions.Defaults() + + options.indent = 0 + options.truncate_metadata = truncate_metadata + options.show_field_metadata = show_field_metadata + options.show_schema_metadata = show_schema_metadata + + with nogil: + check_status( + PrettyPrint( + deref(self.schema), + options, + &result + ) + ) + + return frombytes(result, safe=True) + + def _export_to_c(self, out_ptr): + """ + Export to a C ArrowSchema struct, given its pointer. + + Be careful: if you don't pass the ArrowSchema struct to a consumer, + its memory will leak. This is a low-level function intended for + expert users. + """ + check_status(ExportSchema(deref(self.schema), + _as_c_pointer(out_ptr))) + + @staticmethod + def _import_from_c(in_ptr): + """ + Import Schema from a C ArrowSchema struct, given its pointer. + + This is a low-level function intended for expert users. + """ + cdef void* c_ptr = _as_c_pointer(in_ptr) + with nogil: + result = GetResultValue(ImportSchema( c_ptr)) + return pyarrow_wrap_schema(result) + + def __str__(self): + return self.to_string() + + def __repr__(self): + return self.__str__() + + def __arrow_c_schema__(self): + """ + Export to a ArrowSchema PyCapsule + + Unlike _export_to_c, this will not leak memory if the capsule is not used. + """ + cdef ArrowSchema* c_schema + capsule = alloc_c_schema(&c_schema) + + with nogil: + check_status(ExportSchema(deref(self.schema), c_schema)) + + return capsule + + @staticmethod + def _import_from_c_capsule(schema): + """ + Import a Schema from a ArrowSchema PyCapsule + + Parameters + ---------- + schema : PyCapsule + A valid PyCapsule with name 'arrow_schema' containing an + ArrowSchema pointer. + """ + cdef: + ArrowSchema* c_schema + + if not PyCapsule_IsValid(schema, 'arrow_schema'): + raise ValueError( + "Not an ArrowSchema object" + ) + c_schema = PyCapsule_GetPointer(schema, 'arrow_schema') + + with nogil: + result = GetResultValue(ImportSchema(c_schema)) + + return pyarrow_wrap_schema(result) + + +def unify_schemas(schemas, *, promote_options="default"): + """ + Unify schemas by merging fields by name. + + The resulting schema will contain the union of fields from all schemas. + Fields with the same name will be merged. Note that two fields with + different types will fail merging by default. + + - The unified field will inherit the metadata from the schema where + that field is first defined. + - The first N fields in the schema will be ordered the same as the + N fields in the first schema. 
+ + The resulting schema will inherit its metadata from the first input + schema. + + Parameters + ---------- + schemas : list of Schema + Schemas to merge into a single one. + promote_options : str, default default + Accepts strings "default" and "permissive". + Default: null and only null can be unified with another type. + Permissive: types are promoted to the greater common denominator. + + Returns + ------- + Schema + + Raises + ------ + ArrowInvalid : + If any input schema contains fields with duplicate names. + If Fields of the same name are not mergeable. + """ + cdef: + Schema schema + CField.CMergeOptions c_options + vector[shared_ptr[CSchema]] c_schemas + for schema in schemas: + if not isinstance(schema, Schema): + raise TypeError("Expected Schema, got {}".format(type(schema))) + c_schemas.push_back(pyarrow_unwrap_schema(schema)) + + if promote_options == "default": + c_options = CField.CMergeOptions.Defaults() + elif promote_options == "permissive": + c_options = CField.CMergeOptions.Permissive() + else: + raise ValueError(f"Invalid merge mode: {promote_options}") + + return pyarrow_wrap_schema( + GetResultValue(UnifySchemas(c_schemas, c_options))) + + +cdef dict _type_cache = {} + + +cdef DataType primitive_type(Type type): + if type in _type_cache: + return _type_cache[type] + + cdef DataType out = DataType.__new__(DataType) + out.init(GetPrimitiveType(type)) + + _type_cache[type] = out + return out + + +# ----------------------------------------------------------- +# Type factory functions + + +def field(name, type=None, nullable=None, metadata=None): + """ + Create a pyarrow.Field instance. + + Parameters + ---------- + name : str or bytes + Name of the field. + Alternatively, you can also pass an object that implements the Arrow + PyCapsule Protocol for schemas (has an ``__arrow_c_schema__`` method). + type : pyarrow.DataType or str + Arrow datatype of the field or a string matching one. + nullable : bool, default True + Whether the field's values are nullable. + metadata : dict, default None + Optional field metadata, the keys and values must be coercible to + bytes. + + Returns + ------- + field : pyarrow.Field + + Examples + -------- + Create an instance of pyarrow.Field: + + >>> import pyarrow as pa + >>> pa.field('key', pa.int32()) + pyarrow.Field + >>> pa.field('key', pa.int32(), nullable=False) + pyarrow.Field + + >>> field = pa.field('key', pa.int32(), + ... 
metadata={"key": "Something important"}) + >>> field + pyarrow.Field + >>> field.metadata + {b'key': b'Something important'} + + Use the field to create a struct type: + + >>> pa.struct([field]) + StructType(struct) + + A str can also be passed for the type parameter: + + >>> pa.field('key', 'int32') + pyarrow.Field + """ + if hasattr(name, "__arrow_c_schema__"): + if type is not None: + raise ValueError( + "cannot specify 'type' when creating a Field from an ArrowSchema" + ) + field = Field._import_from_c_capsule(name.__arrow_c_schema__()) + if metadata is not None: + field = field.with_metadata(metadata) + if nullable is not None: + field = field.with_nullable(nullable) + return field + + cdef: + Field result = Field.__new__(Field) + DataType _type = ensure_type(type, allow_none=False) + shared_ptr[const CKeyValueMetadata] c_meta + + nullable = True if nullable is None else nullable + + metadata = ensure_metadata(metadata, allow_none=True) + c_meta = pyarrow_unwrap_metadata(metadata) + + if _type.type.id() == _Type_NA and not nullable: + raise ValueError("A null type field may not be non-nullable") + + result.sp_field.reset( + new CField(tobytes(name), _type.sp_type, nullable, c_meta) + ) + result.field = result.sp_field.get() + result.type = _type + + return result + + +cdef set PRIMITIVE_TYPES = set([ + _Type_NA, _Type_BOOL, + _Type_UINT8, _Type_INT8, + _Type_UINT16, _Type_INT16, + _Type_UINT32, _Type_INT32, + _Type_UINT64, _Type_INT64, + _Type_TIMESTAMP, _Type_DATE32, + _Type_TIME32, _Type_TIME64, + _Type_DATE64, + _Type_HALF_FLOAT, + _Type_FLOAT, + _Type_DOUBLE]) + + +def null(): + """ + Create instance of null type. + + Examples + -------- + Create an instance of a null type: + + >>> import pyarrow as pa + >>> pa.null() + DataType(null) + >>> print(pa.null()) + null + + Create a ``Field`` type with a null type and a name: + + >>> pa.field('null_field', pa.null()) + pyarrow.Field + """ + return primitive_type(_Type_NA) + + +def bool_(): + """ + Create instance of boolean type. + + Examples + -------- + Create an instance of a boolean type: + + >>> import pyarrow as pa + >>> pa.bool_() + DataType(bool) + >>> print(pa.bool_()) + bool + + Create a ``Field`` type with a boolean type + and a name: + + >>> pa.field('bool_field', pa.bool_()) + pyarrow.Field + """ + return primitive_type(_Type_BOOL) + + +def uint8(): + """ + Create instance of unsigned int8 type. + + Examples + -------- + Create an instance of unsigned int8 type: + + >>> import pyarrow as pa + >>> pa.uint8() + DataType(uint8) + >>> print(pa.uint8()) + uint8 + + Create an array with unsigned int8 type: + + >>> pa.array([0, 1, 2], type=pa.uint8()) + + [ + 0, + 1, + 2 + ] + """ + return primitive_type(_Type_UINT8) + + +def int8(): + """ + Create instance of signed int8 type. + + Examples + -------- + Create an instance of int8 type: + + >>> import pyarrow as pa + >>> pa.int8() + DataType(int8) + >>> print(pa.int8()) + int8 + + Create an array with int8 type: + + >>> pa.array([0, 1, 2], type=pa.int8()) + + [ + 0, + 1, + 2 + ] + """ + return primitive_type(_Type_INT8) + + +def uint16(): + """ + Create instance of unsigned uint16 type. + + Examples + -------- + Create an instance of unsigned int16 type: + + >>> import pyarrow as pa + >>> pa.uint16() + DataType(uint16) + >>> print(pa.uint16()) + uint16 + + Create an array with unsigned int16 type: + + >>> pa.array([0, 1, 2], type=pa.uint16()) + + [ + 0, + 1, + 2 + ] + """ + return primitive_type(_Type_UINT16) + + +def int16(): + """ + Create instance of signed int16 type. 
+ + Examples + -------- + Create an instance of int16 type: + + >>> import pyarrow as pa + >>> pa.int16() + DataType(int16) + >>> print(pa.int16()) + int16 + + Create an array with int16 type: + + >>> pa.array([0, 1, 2], type=pa.int16()) + + [ + 0, + 1, + 2 + ] + """ + return primitive_type(_Type_INT16) + + +def uint32(): + """ + Create instance of unsigned uint32 type. + + Examples + -------- + Create an instance of unsigned int32 type: + + >>> import pyarrow as pa + >>> pa.uint32() + DataType(uint32) + >>> print(pa.uint32()) + uint32 + + Create an array with unsigned int32 type: + + >>> pa.array([0, 1, 2], type=pa.uint32()) + + [ + 0, + 1, + 2 + ] + """ + return primitive_type(_Type_UINT32) + + +def int32(): + """ + Create instance of signed int32 type. + + Examples + -------- + Create an instance of int32 type: + + >>> import pyarrow as pa + >>> pa.int32() + DataType(int32) + >>> print(pa.int32()) + int32 + + Create an array with int32 type: + + >>> pa.array([0, 1, 2], type=pa.int32()) + + [ + 0, + 1, + 2 + ] + """ + return primitive_type(_Type_INT32) + + +def uint64(): + """ + Create instance of unsigned uint64 type. + + Examples + -------- + Create an instance of unsigned int64 type: + + >>> import pyarrow as pa + >>> pa.uint64() + DataType(uint64) + >>> print(pa.uint64()) + uint64 + + Create an array with unsigned uint64 type: + + >>> pa.array([0, 1, 2], type=pa.uint64()) + + [ + 0, + 1, + 2 + ] + """ + return primitive_type(_Type_UINT64) + + +def int64(): + """ + Create instance of signed int64 type. + + Examples + -------- + Create an instance of int64 type: + + >>> import pyarrow as pa + >>> pa.int64() + DataType(int64) + >>> print(pa.int64()) + int64 + + Create an array with int64 type: + + >>> pa.array([0, 1, 2], type=pa.int64()) + + [ + 0, + 1, + 2 + ] + """ + return primitive_type(_Type_INT64) + + +cdef dict _timestamp_type_cache = {} +cdef dict _time_type_cache = {} +cdef dict _duration_type_cache = {} + + +cdef timeunit_to_string(TimeUnit unit): + if unit == TimeUnit_SECOND: + return 's' + elif unit == TimeUnit_MILLI: + return 'ms' + elif unit == TimeUnit_MICRO: + return 'us' + elif unit == TimeUnit_NANO: + return 'ns' + + +cdef TimeUnit string_to_timeunit(unit) except *: + if unit == 's': + return TimeUnit_SECOND + elif unit == 'ms': + return TimeUnit_MILLI + elif unit == 'us': + return TimeUnit_MICRO + elif unit == 'ns': + return TimeUnit_NANO + else: + raise ValueError(f"Invalid time unit: {unit!r}") + + +def tzinfo_to_string(tz): + """ + Converts a time zone object into a string indicating the name of a time + zone, one of: + * As used in the Olson time zone database (the "tz database" or + "tzdata"), such as "America/New_York" + * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30 + + Parameters + ---------- + tz : datetime.tzinfo + Time zone object + + Returns + ------- + name : str + Time zone name + """ + return frombytes(GetResultValue(TzinfoToString(tz))) + + +def string_to_tzinfo(name): + """ + Convert a time zone name into a time zone object. + + Supported input strings are: + * As used in the Olson time zone database (the "tz database" or + "tzdata"), such as "America/New_York" + * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30 + + Parameters + ---------- + name: str + Time zone name. 
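# --- Illustrative sketch (not part of the pyarrow sources above): the two time
# --- zone helpers live in pyarrow.lib; the concrete tzinfo class returned
# --- depends on the environment (zoneinfo/pytz), so only the name is asserted.
from pyarrow.lib import string_to_tzinfo, tzinfo_to_string

tz = string_to_tzinfo("America/New_York")    # a datetime.tzinfo instance
tzinfo_to_string(tz)                         # should round-trip: "America/New_York"
string_to_tzinfo("+07:30")                   # fixed-offset tzinfo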
+ + Returns + ------- + tz : datetime.tzinfo + Time zone object + """ + cdef PyObject* tz = GetResultValue(StringToTzinfo(name.encode('utf-8'))) + return PyObject_to_object(tz) + + +def timestamp(unit, tz=None): + """ + Create instance of timestamp type with resolution and optional time zone. + + Parameters + ---------- + unit : str + one of 's' [second], 'ms' [millisecond], 'us' [microsecond], or 'ns' + [nanosecond] + tz : str, default None + Time zone name. None indicates time zone naive + + Examples + -------- + Create an instance of timestamp type: + + >>> import pyarrow as pa + >>> pa.timestamp('us') + TimestampType(timestamp[us]) + >>> pa.timestamp('s', tz='America/New_York') + TimestampType(timestamp[s, tz=America/New_York]) + >>> pa.timestamp('s', tz='+07:30') + TimestampType(timestamp[s, tz=+07:30]) + + Use timestamp type when creating a scalar object: + + >>> from datetime import datetime + >>> pa.scalar(datetime(2012, 1, 1), type=pa.timestamp('s', tz='UTC')) + + >>> pa.scalar(datetime(2012, 1, 1), type=pa.timestamp('us')) + + + Returns + ------- + timestamp_type : TimestampType + """ + cdef: + TimeUnit unit_code + c_string c_timezone + + unit_code = string_to_timeunit(unit) + + cdef TimestampType out = TimestampType.__new__(TimestampType) + + if tz is None: + out.init(ctimestamp(unit_code)) + if unit_code in _timestamp_type_cache: + return _timestamp_type_cache[unit_code] + _timestamp_type_cache[unit_code] = out + else: + if not isinstance(tz, (bytes, str)): + tz = tzinfo_to_string(tz) + + c_timezone = tobytes(tz) + out.init(ctimestamp(unit_code, c_timezone)) + + return out + + +def time32(unit): + """ + Create instance of 32-bit time (time of day) type with unit resolution. + + Parameters + ---------- + unit : str + one of 's' [second], or 'ms' [millisecond] + + Returns + ------- + type : pyarrow.Time32Type + + Examples + -------- + >>> import pyarrow as pa + >>> pa.time32('s') + Time32Type(time32[s]) + >>> pa.time32('ms') + Time32Type(time32[ms]) + """ + cdef: + TimeUnit unit_code + c_string c_timezone + + if unit == 's': + unit_code = TimeUnit_SECOND + elif unit == 'ms': + unit_code = TimeUnit_MILLI + else: + raise ValueError(f"Invalid time unit for time32: {unit!r}") + + if unit_code in _time_type_cache: + return _time_type_cache[unit_code] + + cdef Time32Type out = Time32Type.__new__(Time32Type) + + out.init(ctime32(unit_code)) + _time_type_cache[unit_code] = out + + return out + + +def time64(unit): + """ + Create instance of 64-bit time (time of day) type with unit resolution. + + Parameters + ---------- + unit : str + One of 'us' [microsecond], or 'ns' [nanosecond]. + + Returns + ------- + type : pyarrow.Time64Type + + Examples + -------- + >>> import pyarrow as pa + >>> pa.time64('us') + Time64Type(time64[us]) + >>> pa.time64('ns') + Time64Type(time64[ns]) + """ + cdef: + TimeUnit unit_code + c_string c_timezone + + if unit == 'us': + unit_code = TimeUnit_MICRO + elif unit == 'ns': + unit_code = TimeUnit_NANO + else: + raise ValueError(f"Invalid time unit for time64: {unit!r}") + + if unit_code in _time_type_cache: + return _time_type_cache[unit_code] + + cdef Time64Type out = Time64Type.__new__(Time64Type) + + out.init(ctime64(unit_code)) + _time_type_cache[unit_code] = out + + return out + + +def duration(unit): + """ + Create instance of a duration type with unit resolution. + + Parameters + ---------- + unit : str + One of 's' [second], 'ms' [millisecond], 'us' [microsecond], or + 'ns' [nanosecond]. 
+ + Returns + ------- + type : pyarrow.DurationType + + Examples + -------- + Create an instance of duration type: + + >>> import pyarrow as pa + >>> pa.duration('us') + DurationType(duration[us]) + >>> pa.duration('s') + DurationType(duration[s]) + + Create an array with duration type: + + >>> pa.array([0, 1, 2], type=pa.duration('s')) + + [ + 0, + 1, + 2 + ] + """ + cdef: + TimeUnit unit_code + + unit_code = string_to_timeunit(unit) + + if unit_code in _duration_type_cache: + return _duration_type_cache[unit_code] + + cdef DurationType out = DurationType.__new__(DurationType) + + out.init(cduration(unit_code)) + _duration_type_cache[unit_code] = out + + return out + + +def month_day_nano_interval(): + """ + Create instance of an interval type representing months, days and + nanoseconds between two dates. + + Examples + -------- + Create an instance of an month_day_nano_interval type: + + >>> import pyarrow as pa + >>> pa.month_day_nano_interval() + DataType(month_day_nano_interval) + + Create a scalar with month_day_nano_interval type: + + >>> pa.scalar((1, 15, -30), type=pa.month_day_nano_interval()) + + """ + return primitive_type(_Type_INTERVAL_MONTH_DAY_NANO) + + +def date32(): + """ + Create instance of 32-bit date (days since UNIX epoch 1970-01-01). + + Examples + -------- + Create an instance of 32-bit date type: + + >>> import pyarrow as pa + >>> pa.date32() + DataType(date32[day]) + + Create a scalar with 32-bit date type: + + >>> from datetime import date + >>> pa.scalar(date(2012, 1, 1), type=pa.date32()) + + """ + return primitive_type(_Type_DATE32) + + +def date64(): + """ + Create instance of 64-bit date (milliseconds since UNIX epoch 1970-01-01). + + Examples + -------- + Create an instance of 64-bit date type: + + >>> import pyarrow as pa + >>> pa.date64() + DataType(date64[ms]) + + Create a scalar with 64-bit date type: + + >>> from datetime import datetime + >>> pa.scalar(datetime(2012, 1, 1), type=pa.date64()) + + """ + return primitive_type(_Type_DATE64) + + +def float16(): + """ + Create half-precision floating point type. + + Examples + -------- + Create an instance of float16 type: + + >>> import pyarrow as pa + >>> pa.float16() + DataType(halffloat) + >>> print(pa.float16()) + halffloat + + Create an array with float16 type: + + >>> arr = np.array([1.5, np.nan], dtype=np.float16) + >>> a = pa.array(arr, type=pa.float16()) + >>> a + + [ + 15872, + 32256 + ] + + Note that unlike other float types, if you convert this array + to a python list, the types of its elements will be ``np.float16`` + + >>> [type(val) for val in a.to_pylist()] + [, ] + """ + return primitive_type(_Type_HALF_FLOAT) + + +def float32(): + """ + Create single-precision floating point type. + + Examples + -------- + Create an instance of float32 type: + + >>> import pyarrow as pa + >>> pa.float32() + DataType(float) + >>> print(pa.float32()) + float + + Create an array with float32 type: + + >>> pa.array([0.0, 1.0, 2.0], type=pa.float32()) + + [ + 0, + 1, + 2 + ] + """ + return primitive_type(_Type_FLOAT) + + +def float64(): + """ + Create double-precision floating point type. 
+ + Examples + -------- + Create an instance of float64 type: + + >>> import pyarrow as pa + >>> pa.float64() + DataType(double) + >>> print(pa.float64()) + double + + Create an array with float64 type: + + >>> pa.array([0.0, 1.0, 2.0], type=pa.float64()) + + [ + 0, + 1, + 2 + ] + """ + return primitive_type(_Type_DOUBLE) + + +cpdef DataType decimal32(int precision, int scale=0): + """ + Create decimal type with precision and scale and 32-bit width. + + Arrow decimals are fixed-point decimal numbers encoded as a scaled + integer. The precision is the number of significant digits that the + decimal type can represent; the scale is the number of digits after + the decimal point (note the scale can be negative). + + As an example, ``decimal32(7, 3)`` can exactly represent the numbers + 1234.567 and -1234.567 (encoded internally as the 32-bit integers + 1234567 and -1234567, respectively), but neither 12345.67 nor 123.4567. + + ``decimal32(5, -3)`` can exactly represent the number 12345000 + (encoded internally as the 32-bit integer 12345), but neither + 123450000 nor 1234500. + + If you need a precision higher than 9 significant digits, consider + using ``decimal64``, ``decimal128``, or ``decimal256``. + + Parameters + ---------- + precision : int + Must be between 1 and 9 + scale : int + + Returns + ------- + decimal_type : Decimal32Type + + Examples + -------- + Create an instance of decimal type: + + >>> import pyarrow as pa + >>> pa.decimal32(5, 2) + Decimal32Type(decimal32(5, 2)) + + Create an array with decimal type: + + >>> import decimal + >>> a = decimal.Decimal('123.45') + >>> pa.array([a], pa.decimal32(5, 2)) + + [ + 123.45 + ] + """ + cdef shared_ptr[CDataType] decimal_type + if precision < 1 or precision > 9: + raise ValueError("precision should be between 1 and 9") + decimal_type.reset(new CDecimal32Type(precision, scale)) + return pyarrow_wrap_data_type(decimal_type) + + +cpdef DataType decimal64(int precision, int scale=0): + """ + Create decimal type with precision and scale and 64-bit width. + + Arrow decimals are fixed-point decimal numbers encoded as a scaled + integer. The precision is the number of significant digits that the + decimal type can represent; the scale is the number of digits after + the decimal point (note the scale can be negative). + + As an example, ``decimal64(7, 3)`` can exactly represent the numbers + 1234.567 and -1234.567 (encoded internally as the 64-bit integers + 1234567 and -1234567, respectively), but neither 12345.67 nor 123.4567. + + ``decimal64(5, -3)`` can exactly represent the number 12345000 + (encoded internally as the 64-bit integer 12345), but neither + 123450000 nor 1234500. + + If you need a precision higher than 18 significant digits, consider + using ``decimal128``, or ``decimal256``. 
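# --- Illustrative sketch (not part of the pyarrow sources above): decimal256,
# --- whose docstring below has no doctest, used for precision beyond the 38
# --- digits that decimal128 supports.
import decimal
import pyarrow as pa

t = pa.decimal256(50, 10)
pa.array([decimal.Decimal("12345678901234567890.0123456789")], type=t)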
+ + Parameters + ---------- + precision : int + Must be between 1 and 18 + scale : int + + Returns + ------- + decimal_type : Decimal64Type + + Examples + -------- + Create an instance of decimal type: + + >>> import pyarrow as pa + >>> pa.decimal64(5, 2) + Decimal64Type(decimal64(5, 2)) + + Create an array with decimal type: + + >>> import decimal + >>> a = decimal.Decimal('123.45') + >>> pa.array([a], pa.decimal64(5, 2)) + + [ + 123.45 + ] + """ + cdef shared_ptr[CDataType] decimal_type + if precision < 1 or precision > 18: + raise ValueError("precision should be between 1 and 18") + decimal_type.reset(new CDecimal64Type(precision, scale)) + return pyarrow_wrap_data_type(decimal_type) + + +cpdef DataType decimal128(int precision, int scale=0): + """ + Create decimal type with precision and scale and 128-bit width. + + Arrow decimals are fixed-point decimal numbers encoded as a scaled + integer. The precision is the number of significant digits that the + decimal type can represent; the scale is the number of digits after + the decimal point (note the scale can be negative). + + As an example, ``decimal128(7, 3)`` can exactly represent the numbers + 1234.567 and -1234.567 (encoded internally as the 128-bit integers + 1234567 and -1234567, respectively), but neither 12345.67 nor 123.4567. + + ``decimal128(5, -3)`` can exactly represent the number 12345000 + (encoded internally as the 128-bit integer 12345), but neither + 123450000 nor 1234500. + + If you need a precision higher than 38 significant digits, consider + using ``decimal256``. + + Parameters + ---------- + precision : int + Must be between 1 and 38 + scale : int + + Returns + ------- + decimal_type : Decimal128Type + + Examples + -------- + Create an instance of decimal type: + + >>> import pyarrow as pa + >>> pa.decimal128(5, 2) + Decimal128Type(decimal128(5, 2)) + + Create an array with decimal type: + + >>> import decimal + >>> a = decimal.Decimal('123.45') + >>> pa.array([a], pa.decimal128(5, 2)) + + [ + 123.45 + ] + """ + cdef shared_ptr[CDataType] decimal_type + if precision < 1 or precision > 38: + raise ValueError("precision should be between 1 and 38") + decimal_type.reset(new CDecimal128Type(precision, scale)) + return pyarrow_wrap_data_type(decimal_type) + + +cpdef DataType decimal256(int precision, int scale=0): + """ + Create decimal type with precision and scale and 256-bit width. + + Arrow decimals are fixed-point decimal numbers encoded as a scaled + integer. The precision is the number of significant digits that the + decimal type can represent; the scale is the number of digits after + the decimal point (note the scale can be negative). + + For most use cases, the maximum precision offered by ``decimal128`` + is sufficient, and it will result in a more compact and more efficient + encoding. ``decimal256`` is useful if you need a precision higher + than 38 significant digits. + + Parameters + ---------- + precision : int + Must be between 1 and 76 + scale : int + + Returns + ------- + decimal_type : Decimal256Type + """ + cdef shared_ptr[CDataType] decimal_type + if precision < 1 or precision > 76: + raise ValueError("precision should be between 1 and 76") + decimal_type.reset(new CDecimal256Type(precision, scale)) + return pyarrow_wrap_data_type(decimal_type) + + +def string(): + """ + Create UTF8 variable-length string type. 
+ + Examples + -------- + Create an instance of a string type: + + >>> import pyarrow as pa + >>> pa.string() + DataType(string) + + and use the string type to create an array: + + >>> pa.array(['foo', 'bar', 'baz'], type=pa.string()) + + [ + "foo", + "bar", + "baz" + ] + """ + return primitive_type(_Type_STRING) + + +def utf8(): + """ + Alias for string(). + + Examples + -------- + Create an instance of a string type: + + >>> import pyarrow as pa + >>> pa.utf8() + DataType(string) + + and use the string type to create an array: + + >>> pa.array(['foo', 'bar', 'baz'], type=pa.utf8()) + + [ + "foo", + "bar", + "baz" + ] + """ + return string() + + +def binary(int length=-1): + """ + Create variable-length or fixed size binary type. + + Parameters + ---------- + length : int, optional, default -1 + If length == -1 then return a variable length binary type. If length is + greater than or equal to 0 then return a fixed size binary type of + width `length`. + + Examples + -------- + Create an instance of a variable-length binary type: + + >>> import pyarrow as pa + >>> pa.binary() + DataType(binary) + + and use the variable-length binary type to create an array: + + >>> pa.array(['foo', 'bar', 'baz'], type=pa.binary()) + + [ + 666F6F, + 626172, + 62617A + ] + + Create an instance of a fixed-size binary type: + + >>> pa.binary(3) + FixedSizeBinaryType(fixed_size_binary[3]) + + and use the fixed-length binary type to create an array: + + >>> pa.array(['foo', 'bar', 'baz'], type=pa.binary(3)) + + [ + 666F6F, + 626172, + 62617A + ] + """ + if length == -1: + return primitive_type(_Type_BINARY) + + cdef shared_ptr[CDataType] fixed_size_binary_type + fixed_size_binary_type.reset(new CFixedSizeBinaryType(length)) + return pyarrow_wrap_data_type(fixed_size_binary_type) + + +def large_binary(): + """ + Create large variable-length binary type. + + This data type may not be supported by all Arrow implementations. Unless + you need to represent data larger than 2GB, you should prefer binary(). + + Examples + -------- + Create an instance of large variable-length binary type: + + >>> import pyarrow as pa + >>> pa.large_binary() + DataType(large_binary) + + and use the type to create an array: + + >>> pa.array(['foo', 'bar', 'baz'], type=pa.large_binary()) + + [ + 666F6F, + 626172, + 62617A + ] + """ + return primitive_type(_Type_LARGE_BINARY) + + +def large_string(): + """ + Create large UTF8 variable-length string type. + + This data type may not be supported by all Arrow implementations. Unless + you need to represent data larger than 2GB, you should prefer string(). + + Examples + -------- + Create an instance of large UTF8 variable-length binary type: + + >>> import pyarrow as pa + >>> pa.large_string() + DataType(large_string) + + and use the type to create an array: + + >>> pa.array(['foo', 'bar'] * 50, type=pa.large_string()) + + [ + "foo", + "bar", + ... + "foo", + "bar" + ] + """ + return primitive_type(_Type_LARGE_STRING) + + +def large_utf8(): + """ + Alias for large_string(). + + Examples + -------- + Create an instance of large UTF8 variable-length binary type: + + >>> import pyarrow as pa + >>> pa.large_utf8() + DataType(large_string) + + and use the type to create an array: + + >>> pa.array(['foo', 'bar'] * 50, type=pa.large_utf8()) + + [ + "foo", + "bar", + ... + "foo", + "bar" + ] + """ + return large_string() + + +def binary_view(): + """ + Create a variable-length binary view type. 
+ + Examples + -------- + Create an instance of a string type: + + >>> import pyarrow as pa + >>> pa.binary_view() + DataType(binary_view) + """ + return primitive_type(_Type_BINARY_VIEW) + + +def string_view(): + """ + Create UTF8 variable-length string view type. + + Examples + -------- + Create an instance of a string type: + + >>> import pyarrow as pa + >>> pa.string_view() + DataType(string_view) + """ + return primitive_type(_Type_STRING_VIEW) + + +def list_(value_type, int list_size=-1): + """ + Create ListType instance from child data type or field. + + Parameters + ---------- + value_type : DataType or Field + list_size : int, optional, default -1 + If length == -1 then return a variable length list type. If length is + greater than or equal to 0 then return a fixed size list type. + + Returns + ------- + list_type : DataType + + Examples + -------- + Create an instance of ListType: + + >>> import pyarrow as pa + >>> pa.list_(pa.string()) + ListType(list) + >>> pa.list_(pa.int32(), 2) + FixedSizeListType(fixed_size_list[2]) + + Use the ListType to create a scalar: + + >>> pa.scalar(['foo', None], type=pa.list_(pa.string(), 2)) + + + or an array: + + >>> pa.array([[1, 2], [3, 4]], pa.list_(pa.int32(), 2)) + + [ + [ + 1, + 2 + ], + [ + 3, + 4 + ] + ] + """ + cdef: + Field _field + shared_ptr[CDataType] list_type + + if isinstance(value_type, DataType): + _field = field('item', value_type) + elif isinstance(value_type, Field): + _field = value_type + else: + raise TypeError('List requires DataType or Field') + + if list_size == -1: + list_type.reset(new CListType(_field.sp_field)) + else: + if list_size < 0: + raise ValueError("list_size should be a positive integer") + list_type.reset(new CFixedSizeListType(_field.sp_field, list_size)) + + return pyarrow_wrap_data_type(list_type) + + +cpdef LargeListType large_list(value_type): + """ + Create LargeListType instance from child data type or field. + + This data type may not be supported by all Arrow implementations. + Unless you need to represent data larger than 2**31 elements, you should + prefer list_(). + + Parameters + ---------- + value_type : DataType or Field + + Returns + ------- + list_type : DataType + + Examples + -------- + Create an instance of LargeListType: + + >>> import pyarrow as pa + >>> pa.large_list(pa.int8()) + LargeListType(large_list) + + Use the LargeListType to create an array: + + >>> pa.array([[-1, 3]] * 5, type=pa.large_list(pa.int8())) + + [ + [ + -1, + 3 + ], + [ + -1, + 3 + ], + ... + """ + cdef: + DataType data_type + Field _field + shared_ptr[CDataType] list_type + LargeListType out = LargeListType.__new__(LargeListType) + + if isinstance(value_type, DataType): + _field = field('item', value_type) + elif isinstance(value_type, Field): + _field = value_type + else: + raise TypeError('List requires DataType or Field') + + list_type.reset(new CLargeListType(_field.sp_field)) + out.init(list_type) + return out + + +cpdef ListViewType list_view(value_type): + """ + Create ListViewType instance from child data type or field. + + This data type may not be supported by all Arrow implementations + because it is an alternative to the ListType. 
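# --- Illustrative sketch (not part of the pyarrow sources above): in recent
# --- pyarrow releases the view layouts can be used like their non-view
# --- counterparts; support depends on the installed Arrow version.
import pyarrow as pa

pa.array(["foo", "bar", None], type=pa.string_view())
pa.array([b"\x00\x01", None], type=pa.binary_view())
pa.list_view(pa.int32())    # a ListViewType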
+ + Parameters + ---------- + value_type : DataType or Field + + Returns + ------- + list_view_type : DataType + + Examples + -------- + Create an instance of ListViewType: + + >>> import pyarrow as pa + >>> pa.list_view(pa.string()) + ListViewType(list_view) + """ + cdef: + Field _field + shared_ptr[CDataType] list_view_type + + if isinstance(value_type, DataType): + _field = field('item', value_type) + elif isinstance(value_type, Field): + _field = value_type + else: + raise TypeError('ListView requires DataType or Field') + + list_view_type = CMakeListViewType(_field.sp_field) + return pyarrow_wrap_data_type(list_view_type) + + +cpdef LargeListViewType large_list_view(value_type): + """ + Create LargeListViewType instance from child data type or field. + + This data type may not be supported by all Arrow implementations + because it is an alternative to the ListType. + + Parameters + ---------- + value_type : DataType or Field + + Returns + ------- + list_view_type : DataType + + Examples + -------- + Create an instance of LargeListViewType: + + >>> import pyarrow as pa + >>> pa.large_list_view(pa.int8()) + LargeListViewType(large_list_view) + """ + cdef: + Field _field + shared_ptr[CDataType] list_view_type + + if isinstance(value_type, DataType): + _field = field('item', value_type) + elif isinstance(value_type, Field): + _field = value_type + else: + raise TypeError('LargeListView requires DataType or Field') + + list_view_type = CMakeLargeListViewType(_field.sp_field) + return pyarrow_wrap_data_type(list_view_type) + + +cpdef MapType map_(key_type, item_type, keys_sorted=False): + """ + Create MapType instance from key and item data types or fields. + + Parameters + ---------- + key_type : DataType or Field + item_type : DataType or Field + keys_sorted : bool + + Returns + ------- + map_type : DataType + + Examples + -------- + Create an instance of MapType: + + >>> import pyarrow as pa + >>> pa.map_(pa.string(), pa.int32()) + MapType(map) + >>> pa.map_(pa.string(), pa.int32(), keys_sorted=True) + MapType(map) + + Use MapType to create an array: + + >>> data = [[{'key': 'a', 'value': 1}, {'key': 'b', 'value': 2}], [{'key': 'c', 'value': 3}]] + >>> pa.array(data, type=pa.map_(pa.string(), pa.int32(), keys_sorted=True)) + + [ + keys: + [ + "a", + "b" + ] + values: + [ + 1, + 2 + ], + keys: + [ + "c" + ] + values: + [ + 3 + ] + ] + """ + cdef: + Field _key_field + Field _item_field + shared_ptr[CDataType] map_type + MapType out = MapType.__new__(MapType) + + if isinstance(key_type, Field): + if key_type.nullable: + raise TypeError('Map key field should be non-nullable') + _key_field = key_type + else: + _key_field = field('key', ensure_type(key_type, allow_none=False), + nullable=False) + + if isinstance(item_type, Field): + _item_field = item_type + else: + _item_field = field('value', ensure_type(item_type, allow_none=False)) + + map_type.reset(new CMapType(_key_field.sp_field, _item_field.sp_field, + keys_sorted)) + out.init(map_type) + return out + + +cpdef DictionaryType dictionary(index_type, value_type, bint ordered=False): + """ + Dictionary (categorical, or simply encoded) type. 
+ + Parameters + ---------- + index_type : DataType + value_type : DataType + ordered : bool + + Returns + ------- + type : DictionaryType + + Examples + -------- + Create an instance of dictionary type: + + >>> import pyarrow as pa + >>> pa.dictionary(pa.int64(), pa.utf8()) + DictionaryType(dictionary) + + Use dictionary type to create an array: + + >>> pa.array(["a", "b", None, "d"], pa.dictionary(pa.int64(), pa.utf8())) + + ... + -- dictionary: + [ + "a", + "b", + "d" + ] + -- indices: + [ + 0, + 1, + null, + 2 + ] + """ + cdef: + DataType _index_type = ensure_type(index_type, allow_none=False) + DataType _value_type = ensure_type(value_type, allow_none=False) + DictionaryType out = DictionaryType.__new__(DictionaryType) + shared_ptr[CDataType] dict_type + + if _index_type.id not in { + Type_INT8, Type_INT16, Type_INT32, Type_INT64, + Type_UINT8, Type_UINT16, Type_UINT32, Type_UINT64, + }: + raise TypeError("The dictionary index type should be integer.") + + dict_type.reset(new CDictionaryType(_index_type.sp_type, + _value_type.sp_type, ordered == 1)) + out.init(dict_type) + return out + + +def struct(fields): + """ + Create StructType instance from fields. + + A struct is a nested type parameterized by an ordered sequence of types + (which can all be distinct), called its fields. + + Parameters + ---------- + fields : iterable of Fields or tuples, or mapping of strings to DataTypes + Each field must have a UTF8-encoded name, and these field names are + part of the type metadata. + + Examples + -------- + Create an instance of StructType from an iterable of tuples: + + >>> import pyarrow as pa + >>> fields = [ + ... ('f1', pa.int32()), + ... ('f2', pa.string()), + ... ] + >>> struct_type = pa.struct(fields) + >>> struct_type + StructType(struct) + + Retrieve a field from a StructType: + + >>> struct_type[0] + pyarrow.Field + >>> struct_type['f1'] + pyarrow.Field + + Create an instance of StructType from an iterable of Fields: + + >>> fields = [ + ... pa.field('f1', pa.int32()), + ... pa.field('f2', pa.string(), nullable=False), + ... ] + >>> pa.struct(fields) + StructType(struct) + + Returns + ------- + type : DataType + """ + cdef: + Field py_field + vector[shared_ptr[CField]] c_fields + cdef shared_ptr[CDataType] struct_type + + if isinstance(fields, Mapping): + fields = fields.items() + + for item in fields: + if isinstance(item, tuple): + py_field = field(*item) + else: + py_field = item + c_fields.push_back(py_field.sp_field) + + struct_type.reset(new CStructType(c_fields)) + return pyarrow_wrap_data_type(struct_type) + + +cdef _extract_union_params(child_fields, type_codes, + vector[shared_ptr[CField]]* c_fields, + vector[int8_t]* c_type_codes): + cdef: + Field child_field + + for child_field in child_fields: + c_fields[0].push_back(child_field.sp_field) + + if type_codes is not None: + if len(type_codes) != (c_fields.size()): + raise ValueError("type_codes should have the same length " + "as fields") + for code in type_codes: + c_type_codes[0].push_back(code) + else: + c_type_codes[0] = range(c_fields.size()) + + +def sparse_union(child_fields, type_codes=None): + """ + Create SparseUnionType from child fields. + + A sparse union is a nested type where each logical value is taken from + a single child. A buffer of 8-bit type ids indicates which child + a given logical value is to be taken from. + + In a sparse union, each child array should have the same length as the + union array, regardless of the actual number of union values that + refer to it. 
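# --- Illustrative sketch (not part of the pyarrow sources above): building a
# --- sparse union array whose type matches pa.sparse_union([...]); the data
# --- below is made up for the example.
import pyarrow as pa

type_ids = pa.array([0, 1, 0], type=pa.int8())      # which child each slot uses
xs = pa.array([1, 2, 3], type=pa.int64())            # children all have the
ys = pa.array(["a", "b", "c"], type=pa.string())     # same length as the union
arr = pa.UnionArray.from_sparse(type_ids, [xs, ys], field_names=["x", "y"])
# arr.type is the same as:
# pa.sparse_union([pa.field("x", pa.int64()), pa.field("y", pa.string())])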
+ + Parameters + ---------- + child_fields : sequence of Field values + Each field must have a UTF8-encoded name, and these field names are + part of the type metadata. + type_codes : list of integers, default None + + Returns + ------- + type : SparseUnionType + """ + cdef: + vector[shared_ptr[CField]] c_fields + vector[int8_t] c_type_codes + + _extract_union_params(child_fields, type_codes, + &c_fields, &c_type_codes) + + return pyarrow_wrap_data_type( + CMakeSparseUnionType(move(c_fields), move(c_type_codes))) + + +def dense_union(child_fields, type_codes=None): + """ + Create DenseUnionType from child fields. + + A dense union is a nested type where each logical value is taken from + a single child, at a specific offset. A buffer of 8-bit type ids + indicates which child a given logical value is to be taken from, + and a buffer of 32-bit offsets indicates at which physical position + in the given child array the logical value is to be taken from. + + Unlike a sparse union, a dense union allows encoding only the child array + values which are actually referred to by the union array. This is + counterbalanced by the additional footprint of the offsets buffer, and + the additional indirection cost when looking up values. + + Parameters + ---------- + child_fields : sequence of Field values + Each field must have a UTF8-encoded name, and these field names are + part of the type metadata. + type_codes : list of integers, default None + + Returns + ------- + type : DenseUnionType + """ + cdef: + vector[shared_ptr[CField]] c_fields + vector[int8_t] c_type_codes + + _extract_union_params(child_fields, type_codes, + &c_fields, &c_type_codes) + + return pyarrow_wrap_data_type( + CMakeDenseUnionType(move(c_fields), move(c_type_codes))) + + +def union(child_fields, mode, type_codes=None): + """ + Create UnionType from child fields. + + A union is a nested type where each logical value is taken from a + single child. A buffer of 8-bit type ids indicates which child + a given logical value is to be taken from. + + Unions come in two flavors: sparse and dense + (see also `pyarrow.sparse_union` and `pyarrow.dense_union`). + + Parameters + ---------- + child_fields : sequence of Field values + Each field must have a UTF8-encoded name, and these field names are + part of the type metadata. + mode : str + Must be 'sparse' or 'dense' + type_codes : list of integers, default None + + Returns + ------- + type : UnionType + """ + if isinstance(mode, int): + if mode not in (_UnionMode_SPARSE, _UnionMode_DENSE): + raise ValueError("Invalid union mode {0!r}".format(mode)) + else: + if mode == 'sparse': + mode = _UnionMode_SPARSE + elif mode == 'dense': + mode = _UnionMode_DENSE + else: + raise ValueError("Invalid union mode {0!r}".format(mode)) + + if mode == _UnionMode_SPARSE: + return sparse_union(child_fields, type_codes) + else: + return dense_union(child_fields, type_codes) + + +def run_end_encoded(run_end_type, value_type): + """ + Create RunEndEncodedType from run-end and value types. + + Parameters + ---------- + run_end_type : pyarrow.DataType + The integer type of the run_ends array. Must be 'int16', 'int32', or 'int64'. + value_type : pyarrow.DataType + The type of the values array. 
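# --- Illustrative sketch (not part of the pyarrow sources above): creating the
# --- run-end encoded type and producing encoded data via pyarrow.compute.
import pyarrow as pa
import pyarrow.compute as pc

ree_type = pa.run_end_encoded(pa.int32(), pa.string())
encoded = pc.run_end_encode(pa.array(["a", "a", "b", "b", "b"]))
# encoded.type is run_end_encoded(int32, string); pc.run_end_decode reverses it.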
+ + Returns + ------- + type : RunEndEncodedType + """ + cdef: + DataType _run_end_type = ensure_type(run_end_type, allow_none=False) + DataType _value_type = ensure_type(value_type, allow_none=False) + shared_ptr[CDataType] ree_type + + if not _run_end_type.type.id() in [_Type_INT16, _Type_INT32, _Type_INT64]: + raise ValueError("The run_end_type should be 'int16', 'int32', or 'int64'") + ree_type = CMakeRunEndEncodedType(_run_end_type.sp_type, _value_type.sp_type) + return pyarrow_wrap_data_type(ree_type) + + +def json_(DataType storage_type=utf8()): + """ + Create instance of JSON extension type. + + Parameters + ---------- + storage_type : DataType, default pyarrow.string() + The underlying data type. Can be on of the following types: + string, large_string, string_view. + + Returns + ------- + type : JsonType + + Examples + -------- + Create an instance of JSON extension type: + + >>> import pyarrow as pa + >>> pa.json_(pa.utf8()) + JsonType(extension) + + Use the JSON type to create an array: + + >>> pa.array(['{"a": 1}', '{"b": 2}'], type=pa.json_(pa.utf8())) + + [ + "{"a": 1}", + "{"b": 2}" + ] + """ + + cdef JsonType out = JsonType.__new__(JsonType) + c_json_ext_type = GetResultValue(CJsonType.Make(storage_type.sp_type)) + out.init(c_json_ext_type) + return out + + +def uuid(): + """ + Create UuidType instance. + + Returns + ------- + type : UuidType + """ + + cdef UuidType out = UuidType.__new__(UuidType) + c_uuid_ext_type = GetResultValue(CUuidType.Make()) + out.init(c_uuid_ext_type) + return out + + +def fixed_shape_tensor(DataType value_type, shape, dim_names=None, permutation=None): + """ + Create instance of fixed shape tensor extension type with shape and optional + names of tensor dimensions and indices of the desired logical + ordering of dimensions. + + Parameters + ---------- + value_type : DataType + Data type of individual tensor elements. + shape : tuple or list of integers + The physical shape of the contained tensors. + dim_names : tuple or list of strings, default None + Explicit names to tensor dimensions. + permutation : tuple or list integers, default None + Indices of the desired ordering of the original dimensions. + The indices contain a permutation of the values ``[0, 1, .., N-1]`` where + N is the number of dimensions. The permutation indicates which dimension + of the logical layout corresponds to which dimension of the physical tensor. + For more information on this parameter see + :ref:`fixed_shape_tensor_extension`. + + Examples + -------- + Create an instance of fixed shape tensor extension type: + + >>> import pyarrow as pa + >>> tensor_type = pa.fixed_shape_tensor(pa.int32(), [2, 2]) + >>> tensor_type + FixedShapeTensorType(extension) + + Inspect the data type: + + >>> tensor_type.value_type + DataType(int32) + >>> tensor_type.shape + [2, 2] + + Create a table with fixed shape tensor extension array: + + >>> arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]] + >>> storage = pa.array(arr, pa.list_(pa.int32(), 4)) + >>> tensor = pa.ExtensionArray.from_storage(tensor_type, storage) + >>> pa.table([tensor], names=["tensor_array"]) + pyarrow.Table + tensor_array: extension + ---- + tensor_array: [[[1,2,3,4],[10,20,30,40],[100,200,300,400]]] + + Create an instance of fixed shape tensor extension type with names + of tensor dimensions: + + >>> tensor_type = pa.fixed_shape_tensor(pa.int8(), (2, 2, 3), + ... 
dim_names=['C', 'H', 'W']) + >>> tensor_type.dim_names + ['C', 'H', 'W'] + + Create an instance of fixed shape tensor extension type with + permutation: + + >>> tensor_type = pa.fixed_shape_tensor(pa.int8(), (2, 2, 3), + ... permutation=[0, 2, 1]) + >>> tensor_type.permutation + [0, 2, 1] + + Returns + ------- + type : FixedShapeTensorType + """ + + cdef: + vector[int64_t] c_shape + vector[int64_t] c_permutation + vector[c_string] c_dim_names + shared_ptr[CDataType] c_tensor_ext_type + + assert value_type is not None + assert shape is not None + + for i in shape: + c_shape.push_back(i) + + if permutation is not None: + for i in permutation: + c_permutation.push_back(i) + + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + cdef FixedShapeTensorType out = FixedShapeTensorType.__new__(FixedShapeTensorType) + + with nogil: + c_tensor_ext_type = GetResultValue(CFixedShapeTensorType.Make( + value_type.sp_type, c_shape, c_permutation, c_dim_names)) + + out.init(c_tensor_ext_type) + + return out + + +def bool8(): + """ + Create instance of bool8 extension type. + + Examples + -------- + Create an instance of bool8 extension type: + + >>> import pyarrow as pa + >>> type = pa.bool8() + >>> type + Bool8Type(extension) + + Inspect the data type: + + >>> type.storage_type + DataType(int8) + + Create a table with a bool8 array: + + >>> arr = [-1, 0, 1, 2, None] + >>> storage = pa.array(arr, pa.int8()) + >>> other = pa.ExtensionArray.from_storage(type, storage) + >>> pa.table([other], names=["unknown_col"]) + pyarrow.Table + unknown_col: extension + ---- + unknown_col: [[-1,0,1,2,null]] + + Returns + ------- + type : Bool8Type + """ + + cdef Bool8Type out = Bool8Type.__new__(Bool8Type) + + c_type = GetResultValue(CBool8Type.Make()) + + out.init(c_type) + + return out + + +def opaque(DataType storage_type, str type_name not None, str vendor_name not None): + """ + Create instance of opaque extension type. + + Parameters + ---------- + storage_type : DataType + The underlying data type. + type_name : str + The name of the type in the external system. + vendor_name : str + The name of the external system. 
+ + Examples + -------- + Create an instance of an opaque extension type: + + >>> import pyarrow as pa + >>> type = pa.opaque(pa.binary(), "other", "jdbc") + >>> type + OpaqueType(extension) + + Inspect the data type: + + >>> type.storage_type + DataType(binary) + >>> type.type_name + 'other' + >>> type.vendor_name + 'jdbc' + + Create a table with an opaque array: + + >>> arr = [None, b"foobar"] + >>> storage = pa.array(arr, pa.binary()) + >>> other = pa.ExtensionArray.from_storage(type, storage) + >>> pa.table([other], names=["unknown_col"]) + pyarrow.Table + unknown_col: extension + ---- + unknown_col: [[null,666F6F626172]] + + Returns + ------- + type : OpaqueType + """ + + cdef: + c_string c_type_name = tobytes(type_name) + c_string c_vendor_name = tobytes(vendor_name) + shared_ptr[COpaqueType] c_opaque_type = make_shared[COpaqueType]( + storage_type.sp_type, c_type_name, c_vendor_name) + shared_ptr[CDataType] c_type = static_pointer_cast[CDataType, COpaqueType](c_opaque_type) + OpaqueType out = OpaqueType.__new__(OpaqueType) + out.init(c_type) + return out + + +cdef dict _type_aliases = { + 'null': null, + 'bool': bool_, + 'boolean': bool_, + 'i1': int8, + 'int8': int8, + 'i2': int16, + 'int16': int16, + 'i4': int32, + 'int32': int32, + 'i8': int64, + 'int64': int64, + 'u1': uint8, + 'uint8': uint8, + 'u2': uint16, + 'uint16': uint16, + 'u4': uint32, + 'uint32': uint32, + 'u8': uint64, + 'uint64': uint64, + 'f2': float16, + 'halffloat': float16, + 'float16': float16, + 'f4': float32, + 'float': float32, + 'float32': float32, + 'f8': float64, + 'double': float64, + 'float64': float64, + 'string': string, + 'str': string, + 'utf8': string, + 'binary': binary, + 'large_string': large_string, + 'large_str': large_string, + 'large_utf8': large_string, + 'large_binary': large_binary, + 'binary_view': binary_view, + 'string_view': string_view, + 'date32': date32, + 'date64': date64, + 'date32[day]': date32, + 'date64[ms]': date64, + 'time32[s]': time32('s'), + 'time32[ms]': time32('ms'), + 'time64[us]': time64('us'), + 'time64[ns]': time64('ns'), + 'timestamp[s]': timestamp('s'), + 'timestamp[ms]': timestamp('ms'), + 'timestamp[us]': timestamp('us'), + 'timestamp[ns]': timestamp('ns'), + 'duration[s]': duration('s'), + 'duration[ms]': duration('ms'), + 'duration[us]': duration('us'), + 'duration[ns]': duration('ns'), + 'month_day_nano_interval': month_day_nano_interval(), +} + + +def type_for_alias(name): + """ + Return DataType given a string alias if one exists. + + Parameters + ---------- + name : str + The alias of the DataType that should be retrieved. + + Returns + ------- + type : DataType + """ + name = name.lower() + try: + alias = _type_aliases[name] + except KeyError: + raise ValueError('No type alias for {0}'.format(name)) + + if isinstance(alias, DataType): + return alias + return alias() + + +cpdef DataType ensure_type(object ty, bint allow_none=False): + if allow_none and ty is None: + return None + elif isinstance(ty, DataType): + return ty + elif isinstance(ty, str): + return type_for_alias(ty) + else: + raise TypeError('DataType expected, got {!r}'.format(type(ty))) + + +def schema(fields, metadata=None): + """ + Construct pyarrow.Schema from collection of fields. + + Parameters + ---------- + fields : iterable of Fields or tuples, or mapping of strings to DataTypes + Can also pass an object that implements the Arrow PyCapsule Protocol + for schemas (has an ``__arrow_c_schema__`` method). + metadata : dict, default None + Keys and values must be coercible to bytes. 
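# --- Illustrative sketch (not part of the pyarrow sources above): resolving a
# --- few of the aliases listed in _type_aliases via type_for_alias.
import pyarrow as pa

pa.type_for_alias("i4")              # int32
pa.type_for_alias("timestamp[ms]")   # timestamp[ms]
pa.type_for_alias("str")             # string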
+ + Examples + -------- + Create a Schema from iterable of tuples: + + >>> import pyarrow as pa + >>> pa.schema([ + ... ('some_int', pa.int32()), + ... ('some_string', pa.string()), + ... pa.field('some_required_string', pa.string(), nullable=False) + ... ]) + some_int: int32 + some_string: string + some_required_string: string not null + + Create a Schema from iterable of Fields: + + >>> pa.schema([ + ... pa.field('some_int', pa.int32()), + ... pa.field('some_string', pa.string()) + ... ]) + some_int: int32 + some_string: string + + DataTypes can also be passed as strings. The following is equivalent to the + above example: + + >>> pa.schema([ + ... pa.field('some_int', "int32"), + ... pa.field('some_string', "string") + ... ]) + some_int: int32 + some_string: string + + Or more concisely: + + >>> pa.schema([ + ... ('some_int', "int32"), + ... ('some_string', "string") + ... ]) + some_int: int32 + some_string: string + + Returns + ------- + schema : pyarrow.Schema + """ + cdef: + shared_ptr[const CKeyValueMetadata] c_meta + shared_ptr[CSchema] c_schema + Schema result + Field py_field + vector[shared_ptr[CField]] c_fields + + if hasattr(fields, "__arrow_c_schema__"): + result = Schema._import_from_c_capsule(fields.__arrow_c_schema__()) + if metadata is not None: + result = result.with_metadata(metadata) + return result + + if isinstance(fields, Mapping): + fields = fields.items() + + for item in fields: + if isinstance(item, tuple): + py_field = field(*item) + else: + py_field = item + if py_field is None: + raise TypeError("field or tuple expected, got None") + c_fields.push_back(py_field.sp_field) + + metadata = ensure_metadata(metadata, allow_none=True) + c_meta = pyarrow_unwrap_metadata(metadata) + + c_schema.reset(new CSchema(c_fields, c_meta)) + result = Schema.__new__(Schema) + result.init_schema(c_schema) + + return result + + +def from_numpy_dtype(object dtype): + """ + Convert NumPy dtype to pyarrow.DataType. + + Parameters + ---------- + dtype : the numpy dtype to convert + + + Examples + -------- + Create a pyarrow DataType from NumPy dtype: + + >>> import pyarrow as pa + >>> import numpy as np + >>> pa.from_numpy_dtype(np.dtype('float16')) + DataType(halffloat) + >>> pa.from_numpy_dtype('U') + DataType(string) + >>> pa.from_numpy_dtype(bool) + DataType(bool) + >>> pa.from_numpy_dtype(np.str_) + DataType(string) + """ + dtype = np.dtype(dtype) + return pyarrow_wrap_data_type(GetResultValue(NumPyDtypeToArrow(dtype))) + + +def is_boolean_value(object obj): + """ + Check if the object is a boolean. + + Parameters + ---------- + obj : object + The object to check + """ + return IsPyBool(obj) + + +def is_integer_value(object obj): + """ + Check if the object is an integer. + + Parameters + ---------- + obj : object + The object to check + """ + return IsPyInt(obj) + + +def is_float_value(object obj): + """ + Check if the object is a float. 
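# --- Illustrative sketch (not part of the pyarrow sources above): the *_value
# --- helpers classify plain Python scalars; they are re-exported through
# --- pyarrow.types (see the import in pyarrow/types.py below).
import pyarrow.types as pat

pat.is_integer_value(3)       # True
pat.is_float_value(3.0)       # True
pat.is_boolean_value("yes")   # False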
+ + Parameters + ---------- + obj : object + The object to check + """ + return IsPyFloat(obj) + + +cdef class _ExtensionRegistryNanny(_Weakrefable): + # Keep the registry alive until we have unregistered PyExtensionType + cdef: + shared_ptr[CExtensionTypeRegistry] registry + + def __cinit__(self): + self.registry = CExtensionTypeRegistry.GetGlobalRegistry() + + def release_registry(self): + self.registry.reset() + + +_registry_nanny = _ExtensionRegistryNanny() + + +def _register_py_extension_type(): + cdef: + DataType storage_type + shared_ptr[CExtensionType] cpy_ext_type + c_string c_extension_name = tobytes("arrow.py_extension_type") + + # Make a dummy C++ ExtensionType + storage_type = null() + check_status(CPyExtensionType.FromClass( + storage_type.sp_type, c_extension_name, PyExtensionType, + &cpy_ext_type)) + check_status( + RegisterPyExtensionType( cpy_ext_type)) + + +def _unregister_py_extension_types(): + # This needs to be done explicitly before the Python interpreter is + # finalized. If the C++ type is destroyed later in the process + # teardown stage, it will invoke CPython APIs such as Py_DECREF + # with a destroyed interpreter. + unregister_extension_type("arrow.py_extension_type") + for ext_type in _python_extension_types_registry: + try: + unregister_extension_type(ext_type.extension_name) + except KeyError: + pass + _registry_nanny.release_registry() + + +_register_py_extension_type() +atexit.register(_unregister_py_extension_types) + + +# +# PyCapsule export utilities +# + +cdef void pycapsule_schema_deleter(object schema_capsule) noexcept: + cdef ArrowSchema* schema = PyCapsule_GetPointer( + schema_capsule, 'arrow_schema' + ) + if schema.release != NULL: + schema.release(schema) + + free(schema) + +cdef object alloc_c_schema(ArrowSchema** c_schema): + c_schema[0] = malloc(sizeof(ArrowSchema)) + # Ensure the capsule destructor doesn't call a random release pointer + c_schema[0].release = NULL + return PyCapsule_New(c_schema[0], 'arrow_schema', &pycapsule_schema_deleter) + + +cdef void pycapsule_array_deleter(object array_capsule) noexcept: + cdef: + ArrowArray* array + # Do not invoke the deleter on a used/moved capsule + array = cpython.PyCapsule_GetPointer( + array_capsule, 'arrow_array' + ) + if array.release != NULL: + array.release(array) + + free(array) + +cdef object alloc_c_array(ArrowArray** c_array): + c_array[0] = malloc(sizeof(ArrowArray)) + # Ensure the capsule destructor doesn't call a random release pointer + c_array[0].release = NULL + return PyCapsule_New(c_array[0], 'arrow_array', &pycapsule_array_deleter) + + +cdef void pycapsule_stream_deleter(object stream_capsule) noexcept: + cdef: + ArrowArrayStream* stream + # Do not invoke the deleter on a used/moved capsule + stream = PyCapsule_GetPointer( + stream_capsule, 'arrow_array_stream' + ) + if stream.release != NULL: + stream.release(stream) + + free(stream) + +cdef object alloc_c_stream(ArrowArrayStream** c_stream): + c_stream[0] = malloc(sizeof(ArrowArrayStream)) + # Ensure the capsule destructor doesn't call a random release pointer + c_stream[0].release = NULL + return PyCapsule_New(c_stream[0], 'arrow_array_stream', &pycapsule_stream_deleter) + + +cdef void pycapsule_device_array_deleter(object array_capsule) noexcept: + cdef: + ArrowDeviceArray* device_array + # Do not invoke the deleter on a used/moved capsule + device_array = cpython.PyCapsule_GetPointer( + array_capsule, 'arrow_device_array' + ) + if device_array.array.release != NULL: + device_array.array.release(&device_array.array) + + 
free(device_array) + + +cdef object alloc_c_device_array(ArrowDeviceArray** c_array): + c_array[0] = malloc(sizeof(ArrowDeviceArray)) + # Ensure the capsule destructor doesn't call a random release pointer + c_array[0].array.release = NULL + return PyCapsule_New( + c_array[0], 'arrow_device_array', &pycapsule_device_array_deleter) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/types.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/types.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb5cfcf8b7393d3bd779ffb4d5fc58eda630501 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pyarrow/types.py @@ -0,0 +1,325 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tools for dealing with Arrow type metadata in Python + + +from pyarrow.lib import (is_boolean_value, # noqa + is_integer_value, + is_float_value) + +import pyarrow.lib as lib +from pyarrow.util import doc + + +_SIGNED_INTEGER_TYPES = {lib.Type_INT8, lib.Type_INT16, lib.Type_INT32, + lib.Type_INT64} +_UNSIGNED_INTEGER_TYPES = {lib.Type_UINT8, lib.Type_UINT16, lib.Type_UINT32, + lib.Type_UINT64} +_INTEGER_TYPES = _SIGNED_INTEGER_TYPES | _UNSIGNED_INTEGER_TYPES +_FLOATING_TYPES = {lib.Type_HALF_FLOAT, lib.Type_FLOAT, lib.Type_DOUBLE} +_DECIMAL_TYPES = {lib.Type_DECIMAL32, lib.Type_DECIMAL64, lib.Type_DECIMAL128, + lib.Type_DECIMAL256} +_DATE_TYPES = {lib.Type_DATE32, lib.Type_DATE64} +_TIME_TYPES = {lib.Type_TIME32, lib.Type_TIME64} +_INTERVAL_TYPES = {lib.Type_INTERVAL_MONTH_DAY_NANO} +_TEMPORAL_TYPES = ({lib.Type_TIMESTAMP, + lib.Type_DURATION} | _TIME_TYPES | _DATE_TYPES | + _INTERVAL_TYPES) +_UNION_TYPES = {lib.Type_SPARSE_UNION, lib.Type_DENSE_UNION} +_NESTED_TYPES = {lib.Type_LIST, lib.Type_FIXED_SIZE_LIST, lib.Type_LARGE_LIST, + lib.Type_LIST_VIEW, lib.Type_LARGE_LIST_VIEW, + lib.Type_STRUCT, lib.Type_MAP} | _UNION_TYPES + + +@doc(datatype="null") +def is_null(t): + """ + Return True if value is an instance of type: {datatype}. 
+ + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_NA + + +@doc(is_null, datatype="boolean") +def is_boolean(t): + return t.id == lib.Type_BOOL + + +@doc(is_null, datatype="any integer") +def is_integer(t): + return t.id in _INTEGER_TYPES + + +@doc(is_null, datatype="signed integer") +def is_signed_integer(t): + return t.id in _SIGNED_INTEGER_TYPES + + +@doc(is_null, datatype="unsigned integer") +def is_unsigned_integer(t): + return t.id in _UNSIGNED_INTEGER_TYPES + + +@doc(is_null, datatype="int8") +def is_int8(t): + return t.id == lib.Type_INT8 + + +@doc(is_null, datatype="int16") +def is_int16(t): + return t.id == lib.Type_INT16 + + +@doc(is_null, datatype="int32") +def is_int32(t): + return t.id == lib.Type_INT32 + + +@doc(is_null, datatype="int64") +def is_int64(t): + return t.id == lib.Type_INT64 + + +@doc(is_null, datatype="uint8") +def is_uint8(t): + return t.id == lib.Type_UINT8 + + +@doc(is_null, datatype="uint16") +def is_uint16(t): + return t.id == lib.Type_UINT16 + + +@doc(is_null, datatype="uint32") +def is_uint32(t): + return t.id == lib.Type_UINT32 + + +@doc(is_null, datatype="uint64") +def is_uint64(t): + return t.id == lib.Type_UINT64 + + +@doc(is_null, datatype="floating point numeric") +def is_floating(t): + return t.id in _FLOATING_TYPES + + +@doc(is_null, datatype="float16 (half-precision)") +def is_float16(t): + return t.id == lib.Type_HALF_FLOAT + + +@doc(is_null, datatype="float32 (single precision)") +def is_float32(t): + return t.id == lib.Type_FLOAT + + +@doc(is_null, datatype="float64 (double precision)") +def is_float64(t): + return t.id == lib.Type_DOUBLE + + +@doc(is_null, datatype="list") +def is_list(t): + return t.id == lib.Type_LIST + + +@doc(is_null, datatype="large list") +def is_large_list(t): + return t.id == lib.Type_LARGE_LIST + + +@doc(is_null, datatype="fixed size list") +def is_fixed_size_list(t): + return t.id == lib.Type_FIXED_SIZE_LIST + + +@doc(is_null, datatype="list view") +def is_list_view(t): + return t.id == lib.Type_LIST_VIEW + + +@doc(is_null, datatype="large list view") +def is_large_list_view(t): + return t.id == lib.Type_LARGE_LIST_VIEW + + +@doc(is_null, datatype="struct") +def is_struct(t): + return t.id == lib.Type_STRUCT + + +@doc(is_null, datatype="union") +def is_union(t): + return t.id in _UNION_TYPES + + +@doc(is_null, datatype="nested type") +def is_nested(t): + return t.id in _NESTED_TYPES + + +@doc(is_null, datatype="run-end encoded") +def is_run_end_encoded(t): + return t.id == lib.Type_RUN_END_ENCODED + + +@doc(is_null, datatype="date, time, timestamp or duration") +def is_temporal(t): + return t.id in _TEMPORAL_TYPES + + +@doc(is_null, datatype="timestamp") +def is_timestamp(t): + return t.id == lib.Type_TIMESTAMP + + +@doc(is_null, datatype="duration") +def is_duration(t): + return t.id == lib.Type_DURATION + + +@doc(is_null, datatype="time") +def is_time(t): + return t.id in _TIME_TYPES + + +@doc(is_null, datatype="time32") +def is_time32(t): + return t.id == lib.Type_TIME32 + + +@doc(is_null, datatype="time64") +def is_time64(t): + return t.id == lib.Type_TIME64 + + +@doc(is_null, datatype="variable-length binary") +def is_binary(t): + return t.id == lib.Type_BINARY + + +@doc(is_null, datatype="large variable-length binary") +def is_large_binary(t): + return t.id == lib.Type_LARGE_BINARY + + +@doc(method="is_string") +def is_unicode(t): + """ + Alias for {method}. 
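A short sketch of how the `pyarrow.types` predicates above are typically combined, assuming a small throwaway schema, to select columns by type family:

```python
import pyarrow as pa
import pyarrow.types as pat

schema = pa.schema([("a", pa.int32()), ("b", pa.float64()), ("c", pa.string())])

# Keep only numeric columns by testing each field's type with the is_* helpers.
numeric = [f.name for f in schema if pat.is_integer(f.type) or pat.is_floating(f.type)]
print(numeric)  # ['a', 'b']
```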
+ + Parameters + ---------- + t : DataType + """ + return is_string(t) + + +@doc(is_null, datatype="string (utf8 unicode)") +def is_string(t): + return t.id == lib.Type_STRING + + +@doc(is_unicode, method="is_large_string") +def is_large_unicode(t): + return is_large_string(t) + + +@doc(is_null, datatype="large string (utf8 unicode)") +def is_large_string(t): + return t.id == lib.Type_LARGE_STRING + + +@doc(is_null, datatype="fixed size binary") +def is_fixed_size_binary(t): + return t.id == lib.Type_FIXED_SIZE_BINARY + + +@doc(is_null, datatype="variable-length binary view") +def is_binary_view(t): + return t.id == lib.Type_BINARY_VIEW + + +@doc(is_null, datatype="variable-length string (utf-8) view") +def is_string_view(t): + return t.id == lib.Type_STRING_VIEW + + +@doc(is_null, datatype="date") +def is_date(t): + return t.id in _DATE_TYPES + + +@doc(is_null, datatype="date32 (days)") +def is_date32(t): + return t.id == lib.Type_DATE32 + + +@doc(is_null, datatype="date64 (milliseconds)") +def is_date64(t): + return t.id == lib.Type_DATE64 + + +@doc(is_null, datatype="map") +def is_map(t): + return t.id == lib.Type_MAP + + +@doc(is_null, datatype="decimal") +def is_decimal(t): + return t.id in _DECIMAL_TYPES + + +@doc(is_null, datatype="decimal32") +def is_decimal32(t): + return t.id == lib.Type_DECIMAL32 + + +@doc(is_null, datatype="decimal64") +def is_decimal64(t): + return t.id == lib.Type_DECIMAL64 + + +@doc(is_null, datatype="decimal128") +def is_decimal128(t): + return t.id == lib.Type_DECIMAL128 + + +@doc(is_null, datatype="decimal256") +def is_decimal256(t): + return t.id == lib.Type_DECIMAL256 + + +@doc(is_null, datatype="dictionary-encoded") +def is_dictionary(t): + return t.id == lib.Type_DICTIONARY + + +@doc(is_null, datatype="interval") +def is_interval(t): + return t.id == lib.Type_INTERVAL_MONTH_DAY_NANO + + +@doc(is_null, datatype="primitive type") +def is_primitive(t): + return lib._is_primitive(t.id) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..609501c8c5930a400204046b9dfa63a8132e9086 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/__init__.py @@ -0,0 +1,133 @@ +""" +.. 
codeauthor:: Tsuyoshi Hombashi +""" + +from dataproperty import LineBreakHandling + +from .__version__ import __author__, __copyright__, __email__, __license__, __version__ +from ._factory import TableWriterFactory +from ._function import dumps_tabledata +from ._logger import set_logger +from ._table_format import FormatAttr, TableFormat +from .error import ( + EmptyTableDataError, + EmptyTableNameError, + EmptyValueError, + NotSupportedError, + WriterNotFoundError, +) +from .style import Align, Format +from .typehint import ( + Bool, + DateTime, + Dictionary, + Infinity, + Integer, + IpAddress, + List, + Nan, + NoneType, + NullString, + RealNumber, + String, +) +from .writer import ( + AbstractTableWriter, + AsciiDocTableWriter, + BoldUnicodeTableWriter, + BorderlessTableWriter, + CssTableWriter, + CsvTableWriter, + ElasticsearchWriter, + ExcelXlsTableWriter, + ExcelXlsxTableWriter, + HtmlTableWriter, + JavaScriptTableWriter, + JsonLinesTableWriter, + JsonTableWriter, + LatexMatrixWriter, + LatexTableWriter, + LtsvTableWriter, + MarkdownTableWriter, + MediaWikiTableWriter, + NullTableWriter, + NumpyTableWriter, + PandasDataFramePickleWriter, + PandasDataFrameWriter, + PythonCodeTableWriter, + RstCsvTableWriter, + RstGridTableWriter, + RstSimpleTableWriter, + SpaceAlignedTableWriter, + SqliteTableWriter, + TomlTableWriter, + TsvTableWriter, + UnicodeTableWriter, + YamlTableWriter, +) + + +__all__ = ( + "__author__", + "__copyright__", + "__email__", + "__license__", + "__version__", + "LineBreakHandling", + "TableWriterFactory", + "dumps_tabledata", + "set_logger", + "FormatAttr", + "TableFormat", + "Align", + "Format", + "Bool", + "DateTime", + "Dictionary", + "Infinity", + "Integer", + "IpAddress", + "List", + "Nan", + "NoneType", + "NullString", + "RealNumber", + "String", + "EmptyTableDataError", + "EmptyTableNameError", + "EmptyValueError", + "NotSupportedError", + "WriterNotFoundError", + "AbstractTableWriter", + "AsciiDocTableWriter", + "BoldUnicodeTableWriter", + "BorderlessTableWriter", + "CssTableWriter", + "CsvTableWriter", + "ElasticsearchWriter", + "ExcelXlsTableWriter", + "ExcelXlsxTableWriter", + "HtmlTableWriter", + "JavaScriptTableWriter", + "JsonLinesTableWriter", + "JsonTableWriter", + "LatexMatrixWriter", + "LatexTableWriter", + "LtsvTableWriter", + "MarkdownTableWriter", + "MediaWikiTableWriter", + "NullTableWriter", + "NumpyTableWriter", + "PandasDataFramePickleWriter", + "PandasDataFrameWriter", + "PythonCodeTableWriter", + "RstCsvTableWriter", + "RstGridTableWriter", + "RstSimpleTableWriter", + "SpaceAlignedTableWriter", + "SqliteTableWriter", + "TomlTableWriter", + "TsvTableWriter", + "UnicodeTableWriter", + "YamlTableWriter", +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/__version__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/__version__.py new file mode 100644 index 0000000000000000000000000000000000000000..3742416fce9d48332e755b67aefb6b1ff6699db1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/__version__.py @@ -0,0 +1,9 @@ +from typing import Final + + +__author__: Final = "Tsuyoshi Hombashi" +__copyright__: Final = f"Copyright 2016-2025, {__author__}" +__license__: Final = "MIT License" +__version__ = "1.2.1" +__maintainer__: Final = __author__ +__email__: Final = "tsuyoshi.hombashi@gmail.com" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/_converter.py 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..2a8d70441a2dd20e4184eda5ab80bc4a5d116960 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/_converter.py @@ -0,0 +1,11 @@ +""" +.. codeauthor:: Tsuyoshi Hombashi +""" + +import re + + +def strip_quote(text: str, value: str) -> str: + re_replace = re.compile(f"[\"']{value:s}[\"']", re.MULTILINE) + + return re_replace.sub(value, text) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/_factory.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..0244a1cc7e8912800038b7ac6dc64513227029d0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/_factory.py @@ -0,0 +1,274 @@ +""" +.. codeauthor:: Tsuyoshi Hombashi +""" + +import os +from itertools import chain +from typing import Any, Final + +import typepy + +from ._logger import logger +from ._table_format import FormatAttr, TableFormat +from .error import WriterNotFoundError +from .writer import AbstractTableWriter + + +class TableWriterFactory: + """ + A factory class of table writer classes. + """ + + @classmethod + def create_from_file_extension(cls, file_extension: str, **kwargs: Any) -> AbstractTableWriter: + """ + Create a table writer class instance from a file extension. + Supported file extensions are as follows: + + ================== =================================== + Extension Writer Class + ================== =================================== + ``".adoc"`` :py:class:`~.AsciiDocTableWriter` + ``".asciidoc"`` :py:class:`~.AsciiDocTableWriter` + ``".asc"`` :py:class:`~.AsciiDocTableWriter` + ``".css"`` :py:class:`~.CssTableWriter` + ``".csv"`` :py:class:`~.CsvTableWriter` + ``".htm"`` :py:class:`~.HtmlTableWriter` + ``".html"`` :py:class:`~.HtmlTableWriter` + ``".js"`` :py:class:`~.JavaScriptTableWriter` + ``".json"`` :py:class:`~.JsonTableWriter` + ``".jsonl"`` :py:class:`~.JsonLinesTableWriter` + ``".ltsv"`` :py:class:`~.LtsvTableWriter` + ``".ldjson"`` :py:class:`~.JsonLinesTableWriter` + ``".md"`` :py:class:`~.MarkdownTableWriter` + ``".ndjson"`` :py:class:`~.JsonLinesTableWriter` + ``".py"`` :py:class:`~.PythonCodeTableWriter` + ``".rst"`` :py:class:`~.RstGridTableWriter` + ``".tsv"`` :py:class:`~.TsvTableWriter` + ``".xls"`` :py:class:`~.ExcelXlsTableWriter` + ``".xlsx"`` :py:class:`~.ExcelXlsxTableWriter` + ``".sqlite"`` :py:class:`~.SqliteTableWriter` + ``".sqlite3"`` :py:class:`~.SqliteTableWriter` + ``".tsv"`` :py:class:`~.TsvTableWriter` + ``".toml"`` :py:class:`~.TomlTableWriter` + ``".yml"`` :py:class:`~.YamlTableWriter` + ================== =================================== + + :param str file_extension: + File extension string (case insensitive). + :param kwargs: + Keyword arguments that pass to a writer class constructor. + :return: + Writer instance that coincides with the ``file_extension``. + :rtype: + :py:class:`~pytablewriter.writer._table_writer.TableWriterInterface` + :raises pytablewriter.WriterNotFoundError: + |WriterNotFoundError_desc| the file extension. 
+ """ + + ext: Final = os.path.splitext(file_extension)[1] + if typepy.is_null_string(ext): + file_extension = file_extension + else: + file_extension = ext + + file_extension = file_extension.lstrip(".").lower() + + for table_format in TableFormat: + if file_extension not in table_format.file_extensions: + continue + + if table_format.format_attribute & FormatAttr.SECONDARY_EXT: + continue + + logger.debug(f"create a {table_format.writer_class} instance") + + return table_format.writer_class(**kwargs) # type: ignore + + raise WriterNotFoundError( + "\n".join( + [ + f"{file_extension:s} (unknown file extension).", + "", + "acceptable file extensions are: {}.".format(", ".join(cls.get_extensions())), + ] + ) + ) + + @classmethod + def create_from_format_name(cls, format_name: str, **kwargs: Any) -> AbstractTableWriter: + """ + Create a table writer class instance from a format name. + Supported file format names are as follows: + + ============================================= =================================== + Format name Writer Class + ============================================= =================================== + ``"adoc"`` :py:class:`~.AsciiDocTableWriter` + ``"asciidoc"`` :py:class:`~.AsciiDocTableWriter` + ``"css"`` :py:class:`~.CssTableWriter` + ``"csv"`` :py:class:`~.CsvTableWriter` + ``"elasticsearch"`` :py:class:`~.ElasticsearchWriter` + ``"excel"`` :py:class:`~.ExcelXlsxTableWriter` + ``"html"``/``"htm"`` :py:class:`~.HtmlTableWriter` + ``"javascript"``/``"js"`` :py:class:`~.JavaScriptTableWriter` + ``"json"`` :py:class:`~.JsonTableWriter` + ``"json_lines"`` :py:class:`~.JsonLinesTableWriter` + ``"latex_matrix"`` :py:class:`~.LatexMatrixWriter` + ``"latex_table"`` :py:class:`~.LatexTableWriter` + ``"ldjson"`` :py:class:`~.JsonLinesTableWriter` + ``"ltsv"`` :py:class:`~.LtsvTableWriter` + ``"markdown"``/``"md"`` :py:class:`~.MarkdownTableWriter` + ``"mediawiki"`` :py:class:`~.MediaWikiTableWriter` + ``"null"`` :py:class:`~.NullTableWriter` + ``"pandas"`` :py:class:`~.PandasDataFrameWriter` + ``"py"``/``"python"`` :py:class:`~.PythonCodeTableWriter` + ``"rst"``/``"rst_grid"``/``"rst_grid_table"`` :py:class:`~.RstGridTableWriter` + ``"rst_simple"``/``"rst_simple_table"`` :py:class:`~.RstSimpleTableWriter` + ``"rst_csv"``/``"rst_csv_table"`` :py:class:`~.RstCsvTableWriter` + ``"sqlite"`` :py:class:`~.SqliteTableWriter` + ``"ssv"`` :py:class:`~.SpaceAlignedTableWriter` + ``"tsv"`` :py:class:`~.TsvTableWriter` + ``"toml"`` :py:class:`~.TomlTableWriter` + ``"unicode"`` :py:class:`~.UnicodeTableWriter` + ``"yaml"`` :py:class:`~.YamlTableWriter` + ============================================= =================================== + + :param str format_name: + Format name string (case insensitive). + :param kwargs: + Keyword arguments that pass to a writer class constructor. + :return: + Writer instance that coincides with the ``format_name``: + :rtype: + :py:class:`~pytablewriter.writer._table_writer.TableWriterInterface` + :raises pytablewriter.WriterNotFoundError: + |WriterNotFoundError_desc| for the format. 
+ """ + + format_name = format_name.casefold() + + for table_format in TableFormat: + if format_name in table_format.names and not ( + table_format.format_attribute & FormatAttr.SECONDARY_NAME + ): + writer = table_format.writer_class(**kwargs) # type: ignore + logger.debug(f"create a {writer.FORMAT_NAME} instance") + + return writer + + raise WriterNotFoundError( + "\n".join( + [ + f"{format_name} (unknown format name).", + "acceptable format names are: {}.".format(", ".join(cls.get_format_names())), + ] + ) + ) + + @classmethod + def get_format_names(cls) -> list[str]: + """ + :return: Available format names. + :rtype: list + + :Example: + .. code:: python + + >>> import pytablewriter as ptw + >>> for name in ptw.TableWriterFactory.get_format_names(): + ... print(name) + ... + adoc + asciidoc + bold_unicode + borderless + css + csv + elasticsearch + excel + htm + html + javascript + js + json + json_lines + jsonl + latex_matrix + latex_table + ldjson + ltsv + markdown + md + mediawiki + ndjson + null + numpy + pandas + pandas_pickle + py + python + rst + rst_csv + rst_csv_table + rst_grid + rst_grid_table + rst_simple + rst_simple_table + space_aligned + sqlite + ssv + toml + tsv + unicode + yaml + + """ + + return sorted(list(set(chain(*(table_format.names for table_format in TableFormat))))) + + @classmethod + def get_extensions(cls) -> list[str]: + """ + :return: Available file extensions. + :rtype: list + + :Example: + .. code:: python + + >>> import pytablewriter as ptw + >>> for name in ptw.TableWriterFactory.get_extensions(): + ... print(name) + ... + adoc + asc + asciidoc + css + csv + htm + html + js + json + jsonl + ldjson + ltsv + md + ndjson + py + rst + sqlite + sqlite3 + tex + toml + tsv + xls + xlsx + yml + """ + + file_extension_set = set() + for table_format in TableFormat: + for file_extension in table_format.file_extensions: + file_extension_set.add(file_extension) + + return sorted(list(file_extension_set)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/_function.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/_function.py new file mode 100644 index 0000000000000000000000000000000000000000..0e116a1ae13fda44995680c1400ac3b794da2047 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/_function.py @@ -0,0 +1,84 @@ +""" +.. codeauthor:: Tsuyoshi Hombashi +""" + +from datetime import datetime +from enum import Enum +from typing import Any, Optional + +import dataproperty +from pathvalidate import replace_symbol +from tabledata._core import TableData + + +def quote_datetime_formatter(value: datetime) -> str: + return f'"{value.strftime(dataproperty.DefaultValue.DATETIME_FORMAT):s}"' + + +def dateutil_datetime_formatter(value: datetime) -> str: + return 'dateutil.parser.parse("{:s}")'.format( + value.strftime(dataproperty.DefaultValue.DATETIME_FORMAT) + ) + + +def dumps_tabledata(value: TableData, format_name: str = "rst_grid_table", **kwargs: Any) -> str: + """ + :param tabledata.TableData value: Tabular data to dump. + :param str format_name: + Dumped format name of tabular data. + Available formats are described in + :py:meth:`~pytablewriter.TableWriterFactory.create_from_format_name` + + :Example: + .. code:: python + + >>> dumps_tabledata(value) + .. 
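A minimal usage sketch for `TableWriterFactory`, assuming pytablewriter is installed; the table name, headers, and rows are placeholders:

```python
from pytablewriter import TableWriterFactory

# Look up a writer by format name and render a small table.
writer = TableWriterFactory.create_from_format_name("markdown")
writer.table_name = "example"
writer.headers = ["name", "score"]
writer.value_matrix = [["alice", 80], ["bob", 65]]
print(writer.dumps())

# The same lookup can also be driven by a file extension.
csv_writer = TableWriterFactory.create_from_file_extension(".csv")
```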
table:: sample_data + + ====== ====== ====== + attr_a attr_b attr_c + ====== ====== ====== + 1 4.0 a + 2 2.1 bb + 3 120.9 ccc + ====== ====== ====== + """ + + from ._factory import TableWriterFactory + + if not value: + raise TypeError("value must be a tabledata.TableData instance") + + writer = TableWriterFactory.create_from_format_name(format_name) + + for attr_name, attr_value in kwargs.items(): + setattr(writer, attr_name, attr_value) + + writer.from_tabledata(value) + + return writer.dumps() + + +def normalize_enum( + value: Any, enum_class: type[Enum], validate: bool = True, default: Optional[Enum] = None +) -> Any: + if value is None: + return default + + if isinstance(value, enum_class): + return value + + try: + return enum_class[replace_symbol(value.strip(), "_").upper()] + except AttributeError: + if validate: + raise TypeError(f"value must be a {enum_class} or a str: actual={type(value)}") + except KeyError: + if validate: + raise ValueError( + "invalid valid found: expected={}, actual={}".format( + "/".join(item.name for item in enum_class), value + ) + ) + + return value diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/_table_format.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/_table_format.py new file mode 100644 index 0000000000000000000000000000000000000000..fccd04756d2060e608ef83c63038065a9d44bba7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/_table_format.py @@ -0,0 +1,354 @@ +""" +.. codeauthor:: Tsuyoshi Hombashi +""" + +import enum +from collections.abc import Sequence +from typing import Optional + +from .writer import ( + AbstractTableWriter, + AsciiDocTableWriter, + BoldUnicodeTableWriter, + BorderlessTableWriter, + CssTableWriter, + CsvTableWriter, + ElasticsearchWriter, + ExcelXlsTableWriter, + ExcelXlsxTableWriter, + HtmlTableWriter, + JavaScriptTableWriter, + JsonLinesTableWriter, + JsonTableWriter, + LatexMatrixWriter, + LatexTableWriter, + LtsvTableWriter, + MarkdownTableWriter, + MediaWikiTableWriter, + NullTableWriter, + NumpyTableWriter, + PandasDataFramePickleWriter, + PandasDataFrameWriter, + PythonCodeTableWriter, + RstCsvTableWriter, + RstGridTableWriter, + RstSimpleTableWriter, + SpaceAlignedTableWriter, + SqliteTableWriter, + TomlTableWriter, + TsvTableWriter, + UnicodeTableWriter, + YamlTableWriter, +) + + +class FormatAttr: + """ + Bitmaps to represent table attributes. + """ + + NONE = 1 << 1 + + #: Can create a file with the format. + FILE = 1 << 2 + + #: Table format that can represent as a text. + TEXT = 1 << 3 + + #: Table format that can represent as a binary file. + BIN = 1 << 4 + + #: Can create a source code (variables definition) + #: one of the programming language. + SOURCECODE = 1 << 5 + + #: Can call API for external service. + API = 1 << 6 + + SECONDARY_EXT = 1 << 10 + SECONDARY_NAME = 1 << 11 + + +@enum.unique +class TableFormat(enum.Enum): + """ + Enum to represent table format attributes. 
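A hedged sketch of `dumps_tabledata`, reusing the sample rows from the docstring above; `TableData` comes from the `tabledata` package, and its positional argument order (table name, headers, rows) is assumed from that package's usual API:

```python
from tabledata import TableData
from pytablewriter import dumps_tabledata

value = TableData(
    "sample_data",
    ["attr_a", "attr_b", "attr_c"],
    [[1, 4.0, "a"], [2, 2.1, "bb"], [3, 120.9, "ccc"]],
)

# Render the same table in a non-default output format.
print(dumps_tabledata(value, format_name="markdown"))
```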
+ """ + + ASCIIDOC = ( + [AsciiDocTableWriter.FORMAT_NAME, "adoc"], + AsciiDocTableWriter, + FormatAttr.FILE | FormatAttr.TEXT, + ["adoc", "asciidoc", "asc"], + ) + CSV = ([CsvTableWriter.FORMAT_NAME], CsvTableWriter, FormatAttr.FILE | FormatAttr.TEXT, ["csv"]) + CSS = ( + [CssTableWriter.FORMAT_NAME], + CssTableWriter, + FormatAttr.FILE | FormatAttr.TEXT, + ["css"], + ) + ELASTICSEARCH = ( + [ElasticsearchWriter.FORMAT_NAME], # type: ignore + ElasticsearchWriter, + FormatAttr.API, + [], + ) + EXCEL_XLSX = ( + [ExcelXlsxTableWriter.FORMAT_NAME], + ExcelXlsxTableWriter, + FormatAttr.FILE | FormatAttr.BIN, + ["xlsx"], + ) + EXCEL_XLS = ( + [ExcelXlsTableWriter.FORMAT_NAME], + ExcelXlsTableWriter, + FormatAttr.FILE | FormatAttr.BIN | FormatAttr.SECONDARY_NAME, + ["xls"], + ) + HTML = ( + [HtmlTableWriter.FORMAT_NAME, "htm"], + HtmlTableWriter, + FormatAttr.FILE | FormatAttr.TEXT, + ["html", "htm"], + ) + JAVASCRIPT = ( + [JavaScriptTableWriter.FORMAT_NAME, "js"], + JavaScriptTableWriter, + FormatAttr.FILE | FormatAttr.TEXT | FormatAttr.SOURCECODE, + ["js"], + ) + JSON = ( + [JsonTableWriter.FORMAT_NAME], + JsonTableWriter, + FormatAttr.FILE | FormatAttr.TEXT, + ["json"], + ) + JSON_LINES = ( + [JsonLinesTableWriter.FORMAT_NAME, "jsonl", "ldjson", "ndjson"], + JsonLinesTableWriter, + FormatAttr.FILE | FormatAttr.TEXT, + ["jsonl", "ldjson", "ndjson"], + ) + LATEX_MATRIX = ( + [LatexMatrixWriter.FORMAT_NAME], + LatexMatrixWriter, + FormatAttr.FILE | FormatAttr.TEXT, + ["tex"], + ) + LATEX_TABLE = ( + [LatexTableWriter.FORMAT_NAME], + LatexTableWriter, + FormatAttr.FILE | FormatAttr.TEXT | FormatAttr.SECONDARY_EXT, + ["tex"], + ) + LTSV = ( + [LtsvTableWriter.FORMAT_NAME], + LtsvTableWriter, + FormatAttr.FILE | FormatAttr.TEXT, + ["ltsv"], + ) + MARKDOWN = ( + [MarkdownTableWriter.FORMAT_NAME, "md"], + MarkdownTableWriter, + FormatAttr.FILE | FormatAttr.TEXT, + ["md"], + ) + MEDIAWIKI = ( + [MediaWikiTableWriter.FORMAT_NAME], # type: ignore + MediaWikiTableWriter, + FormatAttr.FILE | FormatAttr.TEXT, + [], + ) + NULL = ( + [NullTableWriter.FORMAT_NAME], # type: ignore + NullTableWriter, + FormatAttr.NONE, + [], + ) + NUMPY = ( + [NumpyTableWriter.FORMAT_NAME], + NumpyTableWriter, + FormatAttr.FILE | FormatAttr.TEXT | FormatAttr.SOURCECODE | FormatAttr.SECONDARY_EXT, + ["py"], + ) + PANDAS = ( + [PandasDataFrameWriter.FORMAT_NAME], + PandasDataFrameWriter, + FormatAttr.FILE | FormatAttr.TEXT | FormatAttr.SOURCECODE | FormatAttr.SECONDARY_EXT, + ["py"], + ) + PANDAS_PICKLE = ( + [PandasDataFramePickleWriter.FORMAT_NAME], # type: ignore + PandasDataFramePickleWriter, + FormatAttr.FILE | FormatAttr.BIN, + [], + ) + PYTHON = ( + [PythonCodeTableWriter.FORMAT_NAME, "py"], + PythonCodeTableWriter, + FormatAttr.FILE | FormatAttr.TEXT | FormatAttr.SOURCECODE, + ["py"], + ) + RST_CSV_TABLE = ( + [RstCsvTableWriter.FORMAT_NAME, "rst_csv"], + RstCsvTableWriter, + FormatAttr.FILE | FormatAttr.TEXT | FormatAttr.SECONDARY_EXT, + ["rst"], + ) + RST_GRID_TABLE = ( + [RstGridTableWriter.FORMAT_NAME, "rst_grid", "rst"], + RstGridTableWriter, + FormatAttr.FILE | FormatAttr.TEXT, + ["rst"], + ) + RST_SIMPLE_TABLE = ( + [RstSimpleTableWriter.FORMAT_NAME, "rst_simple"], + RstSimpleTableWriter, + FormatAttr.FILE | FormatAttr.TEXT | FormatAttr.SECONDARY_EXT, + ["rst"], + ) + SPACE_ALIGNED = ( + [SpaceAlignedTableWriter.FORMAT_NAME, "ssv"], # type: ignore + SpaceAlignedTableWriter, + FormatAttr.FILE | FormatAttr.TEXT, + [], + ) + SQLITE = ( + [SqliteTableWriter.FORMAT_NAME], + SqliteTableWriter, + FormatAttr.FILE | 
FormatAttr.BIN, + ["sqlite", "sqlite3"], + ) + TOML = ( + [TomlTableWriter.FORMAT_NAME], + TomlTableWriter, + FormatAttr.FILE | FormatAttr.TEXT, + ["toml"], + ) + TSV = ([TsvTableWriter.FORMAT_NAME], TsvTableWriter, FormatAttr.FILE | FormatAttr.TEXT, ["tsv"]) + UNICODE = ( + [UnicodeTableWriter.FORMAT_NAME], # type: ignore + UnicodeTableWriter, + FormatAttr.TEXT, + [], + ) + YAML = ( + [YamlTableWriter.FORMAT_NAME], + YamlTableWriter, + FormatAttr.FILE | FormatAttr.TEXT, + ["yml"], + ) + BOLD_UNICODE = ( + [BoldUnicodeTableWriter.FORMAT_NAME], # type: ignore + BoldUnicodeTableWriter, + FormatAttr.TEXT, + [], + ) + BORDERLESS = ( + [BorderlessTableWriter.FORMAT_NAME], # type: ignore + BorderlessTableWriter, + FormatAttr.TEXT, + [], + ) + + @property + def names(self) -> list[str]: + """ + List[str]: Names associated with the table format. + """ + + return self.__names + + @property + def writer_class(self) -> type[AbstractTableWriter]: + """ + Type[AbstractTableWriter]: Table writer class object associated with the table format. + """ + + return self.__writer_class + + @property + def format_attribute(self) -> int: + """ + FormatAttr: Table attributes bitmap. + """ + + return self.__format_attribute + + @property + def file_extensions(self) -> list[str]: + """ + List[str]: File extensions associated with the table format. + """ + + return self.__file_extensions + + def __init__( + self, + names: Sequence[str], + writer_class: type[AbstractTableWriter], + format_attribute: int, + file_extensions: Sequence[str], + ) -> None: + self.__names = list(names) + self.__writer_class = writer_class + self.__format_attribute = format_attribute + self.__file_extensions = list(file_extensions) + + @classmethod + def find_all_attr(cls, format_attribute: int) -> list["TableFormat"]: + """Searching table formats that have specific attributes. + + Args: + format_attribute (FormatAttr): + Table format attributes to look for. + + Returns: + List[TableFormat]: Table formats that matched the attribute. + """ + + return [ + table_format + for table_format in TableFormat + if table_format.format_attribute & format_attribute + ] + + @classmethod + def from_name(cls, format_name: str) -> Optional["TableFormat"]: + """Get a table format from a format name. + + Args: + format_name (str): Table format specifier. + + Returns: + Optional[TableFormat]: A table format enum value corresponding to the ``format_name``. + """ + + format_name = format_name.casefold().strip() + + for table_format in TableFormat: + if format_name in table_format.names: + return table_format + + return None + + @classmethod + def from_file_extension(cls, file_extension: str) -> Optional["TableFormat"]: + """Get a table format from a file extension. + + Args: + file_extension (str): File extension. + + Returns: + Optional[TableFormat]: + A table format enum value corresponding to the ``file_extension``. + """ + + ext = file_extension.lower().strip().lstrip(".") + + for table_format in TableFormat: + if ext in table_format.file_extensions: + return table_format + + return None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/error.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/error.py new file mode 100644 index 0000000000000000000000000000000000000000..325be5c2e4395d0898fc7d9b3b43acf011ec48b7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/error.py @@ -0,0 +1,34 @@ +""" +.. 
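A small sketch of the `TableFormat` lookup helpers defined above; the format name and extension strings are arbitrary examples:

```python
from pytablewriter import TableFormat

# Resolve a format from one of its registered names.
fmt = TableFormat.from_name("md")
print(fmt.writer_class.__name__)  # MarkdownTableWriter

# Or resolve it from a file extension.
print(TableFormat.from_file_extension(".tsv").names)  # ['tsv']
```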
codeauthor:: Tsuyoshi Hombashi +""" + + +class NotSupportedError(Exception): + pass + + +class EmptyTableNameError(Exception): + """ + Exception raised when a table writer class of the |table_name| attribute + is null and the class is not accepted null |table_name|. + """ + + +class EmptyValueError(Exception): + """ + Exception raised when a table writer class of the |value_matrix| attribute + is null, and the class is not accepted null |value_matrix|. + """ + + +class EmptyTableDataError(Exception): + """ + Exception raised when a table writer class of the |headers| and + |value_matrix| attributes are null. + """ + + +class WriterNotFoundError(Exception): + """ + Exception raised when appropriate loader writer found. + """ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/py.typed b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/pytablewriter/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a5d2ca92b5248ce798a19f8e14c3492992cae1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/__init__.py @@ -0,0 +1,9 @@ +# Re-export this +from ._safetensors_rust import ( # noqa: F401 + SafetensorError, + __version__, + deserialize, + safe_open, + serialize, + serialize_file, +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/__init__.pyi b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..7781229fe91d0649996e257dccf9f6d0c38823cd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/__init__.pyi @@ -0,0 +1,149 @@ +# Generated content DO NOT EDIT +@staticmethod +def deserialize(bytes): + """ + Opens a safetensors lazily and returns tensors as asked + + Args: + data (`bytes`): + The byte content of a file + + Returns: + (`List[str, Dict[str, Dict[str, any]]]`): + The deserialized content is like: + [("tensor_name", {"shape": [2, 3], "dtype": "F32", "data": b"\0\0.." }), (...)] + """ + pass + +@staticmethod +def serialize(tensor_dict, metadata=None): + """ + Serializes raw data. + + Args: + tensor_dict (`Dict[str, Dict[Any]]`): + The tensor dict is like: + {"tensor_name": {"dtype": "F32", "shape": [2, 3], "data": b"\0\0"}} + metadata (`Dict[str, str]`, *optional*): + The optional purely text annotations + + Returns: + (`bytes`): + The serialized content. + """ + pass + +@staticmethod +def serialize_file(tensor_dict, filename, metadata=None): + """ + Serializes raw data. + + Args: + tensor_dict (`Dict[str, Dict[Any]]`): + The tensor dict is like: + {"tensor_name": {"dtype": "F32", "shape": [2, 3], "data": b"\0\0"}} + filename (`str`, or `os.PathLike`): + The name of the file to write into. + metadata (`Dict[str, str]`, *optional*): + The optional purely text annotations + + Returns: + (`bytes`): + The serialized content. + """ + pass + +class safe_open: + """ + Opens a safetensors lazily and returns tensors as asked + + Args: + filename (`str`, or `os.PathLike`): + The filename to open + + framework (`str`): + The framework you want you tensors in. Supported values: + `pt`, `tf`, `flax`, `numpy`. 
+ + device (`str`, defaults to `"cpu"`): + The device on which you want the tensors. + """ + + def __init__(self, filename, framework, device=...): + pass + def __enter__(self): + """ + Start the context manager + """ + pass + def __exit__(self, _exc_type, _exc_value, _traceback): + """ + Exits the context manager + """ + pass + def get_slice(self, name): + """ + Returns a full slice view object + + Args: + name (`str`): + The name of the tensor you want + + Returns: + (`PySafeSlice`): + A dummy object you can slice into to get a real tensor + Example: + ```python + from safetensors import safe_open + + with safe_open("model.safetensors", framework="pt", device=0) as f: + tensor_part = f.get_slice("embedding")[:, ::8] + + ``` + """ + pass + def get_tensor(self, name): + """ + Returns a full tensor + + Args: + name (`str`): + The name of the tensor you want + + Returns: + (`Tensor`): + The tensor in the framework you opened the file for. + + Example: + ```python + from safetensors import safe_open + + with safe_open("model.safetensors", framework="pt", device=0) as f: + tensor = f.get_tensor("embedding") + + ``` + """ + pass + def keys(self): + """ + Returns the names of the tensors in the file. + + Returns: + (`List[str]`): + The name of the tensors contained in that file + """ + pass + def metadata(self): + """ + Return the special non tensor information in the header + + Returns: + (`Dict[str, str]`): + The freeform metadata. + """ + pass + +class SafetensorError(Exception): + """ + Custom Python Exception for Safetensor errors. + """ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/mlx.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/mlx.py new file mode 100644 index 0000000000000000000000000000000000000000..cf9fe37519c817e4d9db87e8ce53c2dc8b85254f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/mlx.py @@ -0,0 +1,138 @@ +import os +from typing import Dict, Optional, Union + +import numpy as np + +import mlx.core as mx +from safetensors import numpy, safe_open + + +def save(tensors: Dict[str, mx.array], metadata: Optional[Dict[str, str]] = None) -> bytes: + """ + Saves a dictionary of tensors into raw bytes in safetensors format. + + Args: + tensors (`Dict[str, mx.array]`): + The incoming tensors. Tensors need to be contiguous and dense. + metadata (`Dict[str, str]`, *optional*, defaults to `None`): + Optional text only metadata you might want to save in your header. + For instance it can be useful to specify more about the underlying + tensors. This is purely informative and does not affect tensor loading. + + Returns: + `bytes`: The raw bytes representing the format + + Example: + + ```python + from safetensors.mlx import save + import mlx.core as mx + + tensors = {"embedding": mx.zeros((512, 1024)), "attention": mx.zeros((256, 256))} + byte_data = save(tensors) + ``` + """ + np_tensors = _mx2np(tensors) + return numpy.save(np_tensors, metadata=metadata) + + +def save_file( + tensors: Dict[str, mx.array], + filename: Union[str, os.PathLike], + metadata: Optional[Dict[str, str]] = None, +) -> None: + """ + Saves a dictionary of tensors into raw bytes in safetensors format. + + Args: + tensors (`Dict[str, mx.array]`): + The incoming tensors. Tensors need to be contiguous and dense. + filename (`str`, or `os.PathLike`)): + The filename we're saving into. + metadata (`Dict[str, str]`, *optional*, defaults to `None`): + Optional text only metadata you might want to save in your header. 
+ For instance it can be useful to specify more about the underlying + tensors. This is purely informative and does not affect tensor loading. + + Returns: + `None` + + Example: + + ```python + from safetensors.mlx import save_file + import mlx.core as mx + + tensors = {"embedding": mx.zeros((512, 1024)), "attention": mx.zeros((256, 256))} + save_file(tensors, "model.safetensors") + ``` + """ + np_tensors = _mx2np(tensors) + return numpy.save_file(np_tensors, filename, metadata=metadata) + + +def load(data: bytes) -> Dict[str, mx.array]: + """ + Loads a safetensors file into MLX format from pure bytes. + + Args: + data (`bytes`): + The content of a safetensors file + + Returns: + `Dict[str, mx.array]`: dictionary that contains name as key, value as `mx.array` + + Example: + + ```python + from safetensors.mlx import load + + file_path = "./my_folder/bert.safetensors" + with open(file_path, "rb") as f: + data = f.read() + + loaded = load(data) + ``` + """ + flat = numpy.load(data) + return _np2mx(flat) + + +def load_file(filename: Union[str, os.PathLike]) -> Dict[str, mx.array]: + """ + Loads a safetensors file into MLX format. + + Args: + filename (`str`, or `os.PathLike`)): + The name of the file which contains the tensors + + Returns: + `Dict[str, mx.array]`: dictionary that contains name as key, value as `mx.array` + + Example: + + ```python + from safetensors.flax import load_file + + file_path = "./my_folder/bert.safetensors" + loaded = load_file(file_path) + ``` + """ + result = {} + with safe_open(filename, framework="mlx") as f: + for k in f.keys(): + result[k] = f.get_tensor(k) + return result + + +def _np2mx(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, mx.array]: + for k, v in numpy_dict.items(): + numpy_dict[k] = mx.array(v) + return numpy_dict + + +def _mx2np(mx_dict: Dict[str, mx.array]) -> Dict[str, np.array]: + new_dict = {} + for k, v in mx_dict.items(): + new_dict[k] = np.asarray(v) + return new_dict diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/numpy.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..0b245f12c1c949456c9b2edb45a11343e6a8099a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/numpy.py @@ -0,0 +1,176 @@ +import os +import sys +from typing import Dict, Optional, Union + +import numpy as np + +from safetensors import deserialize, safe_open, serialize, serialize_file + + +def _tobytes(tensor: np.ndarray) -> bytes: + if not _is_little_endian(tensor): + tensor = tensor.byteswap(inplace=False) + return tensor.tobytes() + + +def save(tensor_dict: Dict[str, np.ndarray], metadata: Optional[Dict[str, str]] = None) -> bytes: + """ + Saves a dictionary of tensors into raw bytes in safetensors format. + + Args: + tensor_dict (`Dict[str, np.ndarray]`): + The incoming tensors. Tensors need to be contiguous and dense. + metadata (`Dict[str, str]`, *optional*, defaults to `None`): + Optional text only metadata you might want to save in your header. + For instance it can be useful to specify more about the underlying + tensors. This is purely informative and does not affect tensor loading. 
+ + Returns: + `bytes`: The raw bytes representing the format + + Example: + + ```python + from safetensors.numpy import save + import numpy as np + + tensors = {"embedding": np.zeros((512, 1024)), "attention": np.zeros((256, 256))} + byte_data = save(tensors) + ``` + """ + flattened = {k: {"dtype": v.dtype.name, "shape": v.shape, "data": _tobytes(v)} for k, v in tensor_dict.items()} + serialized = serialize(flattened, metadata=metadata) + result = bytes(serialized) + return result + + +def save_file( + tensor_dict: Dict[str, np.ndarray], filename: Union[str, os.PathLike], metadata: Optional[Dict[str, str]] = None +) -> None: + """ + Saves a dictionary of tensors into raw bytes in safetensors format. + + Args: + tensor_dict (`Dict[str, np.ndarray]`): + The incoming tensors. Tensors need to be contiguous and dense. + filename (`str`, or `os.PathLike`)): + The filename we're saving into. + metadata (`Dict[str, str]`, *optional*, defaults to `None`): + Optional text only metadata you might want to save in your header. + For instance it can be useful to specify more about the underlying + tensors. This is purely informative and does not affect tensor loading. + + Returns: + `None` + + Example: + + ```python + from safetensors.numpy import save_file + import numpy as np + + tensors = {"embedding": np.zeros((512, 1024)), "attention": np.zeros((256, 256))} + save_file(tensors, "model.safetensors") + ``` + """ + flattened = {k: {"dtype": v.dtype.name, "shape": v.shape, "data": _tobytes(v)} for k, v in tensor_dict.items()} + serialize_file(flattened, filename, metadata=metadata) + + +def load(data: bytes) -> Dict[str, np.ndarray]: + """ + Loads a safetensors file into numpy format from pure bytes. + + Args: + data (`bytes`): + The content of a safetensors file + + Returns: + `Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray` on cpu + + Example: + + ```python + from safetensors.numpy import load + + file_path = "./my_folder/bert.safetensors" + with open(file_path, "rb") as f: + data = f.read() + + loaded = load(data) + ``` + """ + flat = deserialize(data) + return _view2np(flat) + + +def load_file(filename: Union[str, os.PathLike]) -> Dict[str, np.ndarray]: + """ + Loads a safetensors file into numpy format. 
+ + Args: + filename (`str`, or `os.PathLike`)): + The name of the file which contains the tensors + + Returns: + `Dict[str, np.ndarray]`: dictionary that contains name as key, value as `np.ndarray` + + Example: + + ```python + from safetensors.numpy import load_file + + file_path = "./my_folder/bert.safetensors" + loaded = load_file(file_path) + ``` + """ + result = {} + with safe_open(filename, framework="np") as f: + for k in f.keys(): + result[k] = f.get_tensor(k) + return result + + +_TYPES = { + "F64": np.float64, + "F32": np.float32, + "F16": np.float16, + "I64": np.int64, + "U64": np.uint64, + "I32": np.int32, + "U32": np.uint32, + "I16": np.int16, + "U16": np.uint16, + "I8": np.int8, + "U8": np.uint8, + "BOOL": bool, +} + + +def _getdtype(dtype_str: str) -> np.dtype: + return _TYPES[dtype_str] + + +def _view2np(safeview) -> Dict[str, np.ndarray]: + result = {} + for k, v in safeview: + dtype = _getdtype(v["dtype"]) + arr = np.frombuffer(v["data"], dtype=dtype).reshape(v["shape"]) + result[k] = arr + return result + + +def _is_little_endian(tensor: np.ndarray) -> bool: + byteorder = tensor.dtype.byteorder + if byteorder == "=": + if sys.byteorder == "little": + return True + else: + return False + elif byteorder == "|": + return True + elif byteorder == "<": + return True + elif byteorder == ">": + return False + raise ValueError(f"Unexpected byte order {byteorder}") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/paddle.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/paddle.py new file mode 100644 index 0000000000000000000000000000000000000000..cec368665de31d17757c0c6621df5dc4926bfab1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/paddle.py @@ -0,0 +1,138 @@ +import os +from typing import Dict, Optional, Union + +import numpy as np + +import paddle +from safetensors import numpy + + +def save(tensors: Dict[str, paddle.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes: + """ + Saves a dictionary of tensors into raw bytes in safetensors format. + + Args: + tensors (`Dict[str, paddle.Tensor]`): + The incoming tensors. Tensors need to be contiguous and dense. + metadata (`Dict[str, str]`, *optional*, defaults to `None`): + Optional text only metadata you might want to save in your header. + For instance it can be useful to specify more about the underlying + tensors. This is purely informative and does not affect tensor loading. + + Returns: + `bytes`: The raw bytes representing the format + + Example: + + ```python + from safetensors.paddle import save + import paddle + + tensors = {"embedding": paddle.zeros((512, 1024)), "attention": paddle.zeros((256, 256))} + byte_data = save(tensors) + ``` + """ + np_tensors = _paddle2np(tensors) + return numpy.save(np_tensors, metadata=metadata) + + +def save_file( + tensors: Dict[str, paddle.Tensor], + filename: Union[str, os.PathLike], + metadata: Optional[Dict[str, str]] = None, +) -> None: + """ + Saves a dictionary of tensors into raw bytes in safetensors format. + + Args: + tensors (`Dict[str, paddle.Tensor]`): + The incoming tensors. Tensors need to be contiguous and dense. + filename (`str`, or `os.PathLike`)): + The filename we're saving into. + metadata (`Dict[str, str]`, *optional*, defaults to `None`): + Optional text only metadata you might want to save in your header. + For instance it can be useful to specify more about the underlying + tensors. This is purely informative and does not affect tensor loading. 
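A round-trip sketch with the numpy helpers above; the file name and array shape are placeholders:

```python
import numpy as np
from safetensors.numpy import save_file, load_file

# Write a dictionary of contiguous arrays, then read it back.
tensors = {"embedding": np.zeros((4, 8), dtype=np.float32)}
save_file(tensors, "example.safetensors")

loaded = load_file("example.safetensors")
print(loaded["embedding"].shape, loaded["embedding"].dtype)  # (4, 8) float32
```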
+ + Returns: + `None` + + Example: + + ```python + from safetensors.paddle import save_file + import paddle + + tensors = {"embedding": paddle.zeros((512, 1024)), "attention": paddle.zeros((256, 256))} + save_file(tensors, "model.safetensors") + ``` + """ + np_tensors = _paddle2np(tensors) + return numpy.save_file(np_tensors, filename, metadata=metadata) + + +def load(data: bytes, device: str = "cpu") -> Dict[str, paddle.Tensor]: + """ + Loads a safetensors file into paddle format from pure bytes. + + Args: + data (`bytes`): + The content of a safetensors file + + Returns: + `Dict[str, paddle.Tensor]`: dictionary that contains name as key, value as `paddle.Tensor` on cpu + + Example: + + ```python + from safetensors.paddle import load + + file_path = "./my_folder/bert.safetensors" + with open(file_path, "rb") as f: + data = f.read() + + loaded = load(data) + ``` + """ + flat = numpy.load(data) + return _np2paddle(flat, device) + + +def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, paddle.Tensor]: + """ + Loads a safetensors file into paddle format. + + Args: + filename (`str`, or `os.PathLike`)): + The name of the file which contains the tensors + device (`Union[Dict[str, any], str]`, *optional*, defaults to `cpu`): + The device where the tensors need to be located after load. + available options are all regular paddle device locations + + Returns: + `Dict[str, paddle.Tensor]`: dictionary that contains name as key, value as `paddle.Tensor` + + Example: + + ```python + from safetensors.paddle import load_file + + file_path = "./my_folder/bert.safetensors" + loaded = load_file(file_path) + ``` + """ + flat = numpy.load_file(filename) + output = _np2paddle(flat, device) + return output + + +def _np2paddle(numpy_dict: Dict[str, np.ndarray], device: str = "cpu") -> Dict[str, paddle.Tensor]: + for k, v in numpy_dict.items(): + numpy_dict[k] = paddle.to_tensor(v, place=device) + return numpy_dict + + +def _paddle2np(paddle_dict: Dict[str, paddle.Tensor]) -> Dict[str, np.array]: + for k, v in paddle_dict.items(): + paddle_dict[k] = v.detach().cpu().numpy() + return paddle_dict diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/py.typed b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/tensorflow.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..e2d74b0522698b3748a7da93753e065f4053beea --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/tensorflow.py @@ -0,0 +1,137 @@ +import os +from typing import Dict, Optional, Union + +import numpy as np +import tensorflow as tf + +from safetensors import numpy, safe_open + + +def save(tensors: Dict[str, tf.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes: + """ + Saves a dictionary of tensors into raw bytes in safetensors format. + + Args: + tensors (`Dict[str, tf.Tensor]`): + The incoming tensors. Tensors need to be contiguous and dense. + metadata (`Dict[str, str]`, *optional*, defaults to `None`): + Optional text only metadata you might want to save in your header. + For instance it can be useful to specify more about the underlying + tensors. This is purely informative and does not affect tensor loading. 
+ + Returns: + `bytes`: The raw bytes representing the format + + Example: + + ```python + from safetensors.tensorflow import save + import tensorflow as tf + + tensors = {"embedding": tf.zeros((512, 1024)), "attention": tf.zeros((256, 256))} + byte_data = save(tensors) + ``` + """ + np_tensors = _tf2np(tensors) + return numpy.save(np_tensors, metadata=metadata) + + +def save_file( + tensors: Dict[str, tf.Tensor], + filename: Union[str, os.PathLike], + metadata: Optional[Dict[str, str]] = None, +) -> None: + """ + Saves a dictionary of tensors into raw bytes in safetensors format. + + Args: + tensors (`Dict[str, tf.Tensor]`): + The incoming tensors. Tensors need to be contiguous and dense. + filename (`str`, or `os.PathLike`)): + The filename we're saving into. + metadata (`Dict[str, str]`, *optional*, defaults to `None`): + Optional text only metadata you might want to save in your header. + For instance it can be useful to specify more about the underlying + tensors. This is purely informative and does not affect tensor loading. + + Returns: + `None` + + Example: + + ```python + from safetensors.tensorflow import save_file + import tensorflow as tf + + tensors = {"embedding": tf.zeros((512, 1024)), "attention": tf.zeros((256, 256))} + save_file(tensors, "model.safetensors") + ``` + """ + np_tensors = _tf2np(tensors) + return numpy.save_file(np_tensors, filename, metadata=metadata) + + +def load(data: bytes) -> Dict[str, tf.Tensor]: + """ + Loads a safetensors file into tensorflow format from pure bytes. + + Args: + data (`bytes`): + The content of a safetensors file + + Returns: + `Dict[str, tf.Tensor]`: dictionary that contains name as key, value as `tf.Tensor` on cpu + + Example: + + ```python + from safetensors.tensorflow import load + + file_path = "./my_folder/bert.safetensors" + with open(file_path, "rb") as f: + data = f.read() + + loaded = load(data) + ``` + """ + flat = numpy.load(data) + return _np2tf(flat) + + +def load_file(filename: Union[str, os.PathLike]) -> Dict[str, tf.Tensor]: + """ + Loads a safetensors file into tensorflow format. 
+ + Args: + filename (`str`, or `os.PathLike`)): + The name of the file which contains the tensors + + Returns: + `Dict[str, tf.Tensor]`: dictionary that contains name as key, value as `tf.Tensor` + + Example: + + ```python + from safetensors.tensorflow import load_file + + file_path = "./my_folder/bert.safetensors" + loaded = load_file(file_path) + ``` + """ + result = {} + with safe_open(filename, framework="tf") as f: + for k in f.keys(): + result[k] = f.get_tensor(k) + return result + + +def _np2tf(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, tf.Tensor]: + for k, v in numpy_dict.items(): + numpy_dict[k] = tf.convert_to_tensor(v) + return numpy_dict + + +def _tf2np(tf_dict: Dict[str, tf.Tensor]) -> Dict[str, np.array]: + for k, v in tf_dict.items(): + tf_dict[k] = v.numpy() + return tf_dict diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/torch.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/torch.py new file mode 100644 index 0000000000000000000000000000000000000000..48532ea5996cd807510b97458a0451f092ea0f35 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/safetensors/torch.py @@ -0,0 +1,503 @@ +import os +import sys +from collections import defaultdict +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import torch + +from safetensors import deserialize, safe_open, serialize, serialize_file + + +def storage_ptr(tensor: torch.Tensor) -> int: + try: + return tensor.untyped_storage().data_ptr() + except Exception: + # Fallback for torch==1.10 + try: + return tensor.storage().data_ptr() + except NotImplementedError: + # Fallback for meta storage + return 0 + + +def _end_ptr(tensor: torch.Tensor) -> int: + if tensor.nelement(): + stop = tensor.view(-1)[-1].data_ptr() + _SIZE[tensor.dtype] + else: + stop = tensor.data_ptr() + return stop + + +def storage_size(tensor: torch.Tensor) -> int: + try: + return tensor.untyped_storage().nbytes() + except AttributeError: + # Fallback for torch==1.10 + try: + return tensor.storage().size() * _SIZE[tensor.dtype] + except NotImplementedError: + # Fallback for meta storage + # On torch >=2.0 this is the tensor size + return tensor.nelement() * _SIZE[tensor.dtype] + + +def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]: + filtered_tensors = [] + for shared in tensors: + if len(shared) < 2: + filtered_tensors.append(shared) + continue + + areas = [] + for name in shared: + tensor = state_dict[name] + areas.append((tensor.data_ptr(), _end_ptr(tensor), name)) + areas.sort() + + _, last_stop, last_name = areas[0] + filtered_tensors.append({last_name}) + for start, stop, name in areas[1:]: + if start >= last_stop: + filtered_tensors.append({name}) + else: + filtered_tensors[-1].add(name) + last_stop = stop + + return filtered_tensors + + +def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]: + tensors = defaultdict(set) + for k, v in state_dict.items(): + if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0: + # Need to add device as key because of multiple GPU. 
+ tensors[(v.device, storage_ptr(v), storage_size(v))].add(k) + tensors = list(sorted(tensors.values())) + tensors = _filter_shared_not_shared(tensors, state_dict) + return tensors + + +def _is_complete(tensor: torch.Tensor) -> bool: + return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _SIZE[tensor.dtype] == storage_size(tensor) + + +def _remove_duplicate_names( + state_dict: Dict[str, torch.Tensor], + *, + preferred_names: Optional[List[str]] = None, + discard_names: Optional[List[str]] = None, +) -> Dict[str, List[str]]: + if preferred_names is None: + preferred_names = [] + preferred_names = set(preferred_names) + if discard_names is None: + discard_names = [] + discard_names = set(discard_names) + + shareds = _find_shared_tensors(state_dict) + to_remove = defaultdict(list) + for shared in shareds: + complete_names = set([name for name in shared if _is_complete(state_dict[name])]) + if not complete_names: + raise RuntimeError( + "Error while trying to find names to remove to save state dict, but found no suitable name to keep" + f" for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model" + " since you could be storing much more memory than needed. Please refer to" + " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an" + " issue." + ) + + keep_name = sorted(list(complete_names))[0] + + # Mechanism to preferentially select keys to keep + # coming from the on-disk file to allow + # loading models saved with a different choice + # of keep_name + preferred = complete_names.difference(discard_names) + if preferred: + keep_name = sorted(list(preferred))[0] + + if preferred_names: + preferred = preferred_names.intersection(complete_names) + if preferred: + keep_name = sorted(list(preferred))[0] + for name in sorted(shared): + if name != keep_name: + to_remove[keep_name].append(name) + return to_remove + + +def save_model( + model: torch.nn.Module, filename: str, metadata: Optional[Dict[str, str]] = None, force_contiguous: bool = True +): + """ + Saves a given torch model to specified filename. + This method exists specifically to avoid tensor sharing issues which are + not allowed in `safetensors`. [More information on tensor sharing](../torch_shared_tensors) + + Args: + model (`torch.nn.Module`): + The model to save on disk. + filename (`str`): + The filename location to save the file + metadata (`Dict[str, str]`, *optional*): + Extra information to save along with the file. + Some metadata will be added for each dropped tensors. + This information will not be enough to recover the entire + shared structure but might help understanding things + force_contiguous (`boolean`, *optional*, defaults to True): + Forcing the state_dict to be saved as contiguous tensors. + This has no effect on the correctness of the model, but it + could potentially change performance if the layout of the tensor + was chosen specifically for that reason. 
+ """ + state_dict = model.state_dict() + to_removes = _remove_duplicate_names(state_dict) + + for kept_name, to_remove_group in to_removes.items(): + for to_remove in to_remove_group: + if metadata is None: + metadata = {} + + if to_remove not in metadata: + # Do not override user data + metadata[to_remove] = kept_name + del state_dict[to_remove] + if force_contiguous: + state_dict = {k: v.contiguous() for k, v in state_dict.items()} + try: + save_file(state_dict, filename, metadata=metadata) + except ValueError as e: + msg = str(e) + msg += " Or use save_model(..., force_contiguous=True), read the docs for potential caveats." + raise ValueError(msg) + + +def load_model( + model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device: Union[str, int] = "cpu" +) -> Tuple[List[str], List[str]]: + """ + Loads a given filename onto a torch model. + This method exists specifically to avoid tensor sharing issues which are + not allowed in `safetensors`. [More information on tensor sharing](../torch_shared_tensors) + + Args: + model (`torch.nn.Module`): + The model to load onto. + filename (`str`, or `os.PathLike`): + The filename location to load the file from. + strict (`bool`, *optional*, defaults to True): + Whether to fail if you're missing keys or having unexpected ones. + When false, the function simply returns missing and unexpected names. + device (`Union[str, int]`, *optional*, defaults to `cpu`): + The device where the tensors need to be located after load. + available options are all regular torch device locations. + + Returns: + `(missing, unexpected): (List[str], List[str])` + `missing` are names in the model which were not modified during loading + `unexpected` are names that are on the file, but weren't used during + the load. + """ + state_dict = load_file(filename, device=device) + model_state_dict = model.state_dict() + to_removes = _remove_duplicate_names(model_state_dict, preferred_names=state_dict.keys()) + missing, unexpected = model.load_state_dict(state_dict, strict=False) + missing = set(missing) + for to_remove_group in to_removes.values(): + for to_remove in to_remove_group: + if to_remove not in missing: + unexpected.append(to_remove) + else: + missing.remove(to_remove) + if strict and (missing or unexpected): + missing_keys = ", ".join([f'"{k}"' for k in sorted(missing)]) + unexpected_keys = ", ".join([f'"{k}"' for k in sorted(unexpected)]) + error = f"Error(s) in loading state_dict for {model.__class__.__name__}:" + if missing: + error += f"\n Missing key(s) in state_dict: {missing_keys}" + if unexpected: + error += f"\n Unexpected key(s) in state_dict: {unexpected_keys}" + raise RuntimeError(error) + return missing, unexpected + + +def save(tensors: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes: + """ + Saves a dictionary of tensors into raw bytes in safetensors format. + + Args: + tensors (`Dict[str, torch.Tensor]`): + The incoming tensors. Tensors need to be contiguous and dense. + metadata (`Dict[str, str]`, *optional*, defaults to `None`): + Optional text only metadata you might want to save in your header. + For instance it can be useful to specify more about the underlying + tensors. This is purely informative and does not affect tensor loading. 
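The shared-tensor machinery above (`_find_shared_tensors`, `_remove_duplicate_names`, `save_model`, `load_model`) carries no usage example in its docstrings; the following is a minimal round-trip sketch. The tied-weight module and the file name are hypothetical, not part of this diff.

```python
import torch
from safetensors.torch import save_model, load_model

class TiedModel(torch.nn.Module):
    """Toy model whose embedding and output projection share one storage."""
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(10, 4)
        self.proj = torch.nn.Linear(4, 10, bias=False)
        self.proj.weight = self.embed.weight  # tied weights -> shared storage

model = TiedModel()
# save_file() would refuse the shared storage; save_model() drops the
# duplicate name and records the alias in the file metadata instead.
save_model(model, "tied.safetensors")

# load_model() restores the weights and treats the dropped alias as neither
# missing nor unexpected, so strict loading still succeeds.
missing, unexpected = load_model(TiedModel(), "tied.safetensors")
assert not missing and not unexpected
```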
+ + Returns: + `bytes`: The raw bytes representing the format + + Example: + + ```python + from safetensors.torch import save + import torch + + tensors = {"embedding": torch.zeros((512, 1024)), "attention": torch.zeros((256, 256))} + byte_data = save(tensors) + ``` + """ + serialized = serialize(_flatten(tensors), metadata=metadata) + result = bytes(serialized) + return result + + +def save_file( + tensors: Dict[str, torch.Tensor], + filename: Union[str, os.PathLike], + metadata: Optional[Dict[str, str]] = None, +): + """ + Saves a dictionary of tensors into raw bytes in safetensors format. + + Args: + tensors (`Dict[str, torch.Tensor]`): + The incoming tensors. Tensors need to be contiguous and dense. + filename (`str`, or `os.PathLike`)): + The filename we're saving into. + metadata (`Dict[str, str]`, *optional*, defaults to `None`): + Optional text only metadata you might want to save in your header. + For instance it can be useful to specify more about the underlying + tensors. This is purely informative and does not affect tensor loading. + + Returns: + `None` + + Example: + + ```python + from safetensors.torch import save_file + import torch + + tensors = {"embedding": torch.zeros((512, 1024)), "attention": torch.zeros((256, 256))} + save_file(tensors, "model.safetensors") + ``` + """ + serialize_file(_flatten(tensors), filename, metadata=metadata) + + +def load_file(filename: Union[str, os.PathLike], device: Union[str, int] = "cpu") -> Dict[str, torch.Tensor]: + """ + Loads a safetensors file into torch format. + + Args: + filename (`str`, or `os.PathLike`): + The name of the file which contains the tensors + device (`Union[str, int]`, *optional*, defaults to `cpu`): + The device where the tensors need to be located after load. + available options are all regular torch device locations. + + Returns: + `Dict[str, torch.Tensor]`: dictionary that contains name as key, value as `torch.Tensor` + + Example: + + ```python + from safetensors.torch import load_file + + file_path = "./my_folder/bert.safetensors" + loaded = load_file(file_path) + ``` + """ + result = {} + with safe_open(filename, framework="pt", device=device) as f: + for k in f.keys(): + result[k] = f.get_tensor(k) + return result + + +def load(data: bytes) -> Dict[str, torch.Tensor]: + """ + Loads a safetensors file into torch format from pure bytes. 
+ + Args: + data (`bytes`): + The content of a safetensors file + + Returns: + `Dict[str, torch.Tensor]`: dictionary that contains name as key, value as `torch.Tensor` on cpu + + Example: + + ```python + from safetensors.torch import load + + file_path = "./my_folder/bert.safetensors" + with open(file_path, "rb") as f: + data = f.read() + + loaded = load(data) + ``` + """ + flat = deserialize(data) + return _view2torch(flat) + + +# torch.float8 formats require 2.1; we do not support these dtypes on earlier versions +_float8_e4m3fn = getattr(torch, "float8_e4m3fn", None) +_float8_e5m2 = getattr(torch, "float8_e5m2", None) + +_SIZE = { + torch.int64: 8, + torch.float32: 4, + torch.int32: 4, + torch.bfloat16: 2, + torch.float16: 2, + torch.int16: 2, + torch.uint8: 1, + torch.int8: 1, + torch.bool: 1, + torch.float64: 8, + _float8_e4m3fn: 1, + _float8_e5m2: 1, +} + +_TYPES = { + "F64": torch.float64, + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I64": torch.int64, + # "U64": torch.uint64, + "I32": torch.int32, + # "U32": torch.uint32, + "I16": torch.int16, + # "U16": torch.uint16, + "I8": torch.int8, + "U8": torch.uint8, + "BOOL": torch.bool, + "F8_E4M3": _float8_e4m3fn, + "F8_E5M2": _float8_e5m2, +} + + +def _getdtype(dtype_str: str) -> torch.dtype: + return _TYPES[dtype_str] + + +def _view2torch(safeview) -> Dict[str, torch.Tensor]: + result = {} + for k, v in safeview: + dtype = _getdtype(v["dtype"]) + if len(v["data"]) == 0: + # Workaround because frombuffer doesn't accept zero-size tensors + assert any(x == 0 for x in v["shape"]) + arr = torch.empty(v["shape"], dtype=dtype) + else: + arr = torch.frombuffer(v["data"], dtype=dtype).reshape(v["shape"]) + if sys.byteorder == "big": + arr = torch.from_numpy(arr.numpy().byteswap(inplace=False)) + result[k] = arr + + return result + + +def _tobytes(tensor: torch.Tensor, name: str) -> bytes: + if tensor.layout != torch.strided: + raise ValueError( + f"You are trying to save a sparse tensor: `{name}` which this library does not support." + " You can make it a dense tensor before saving with `.to_dense()` but be aware this might" + " make a much larger file than needed." + ) + + if not tensor.is_contiguous(): + raise ValueError( + f"You are trying to save a non contiguous tensor: `{name}` which is not allowed. It either means you" + " are trying to save tensors which are reference of each other in which case it's recommended to save" + " only the full tensors, and reslice at load time, or simply call `.contiguous()` on your tensor to" + " pack it before saving." 
+ ) + if tensor.device.type != "cpu": + # Moving tensor to cpu before saving + tensor = tensor.to("cpu") + + import ctypes + + import numpy as np + + # When shape is empty (scalar), np.prod returns a float + # we need a int for the following calculations + length = int(np.prod(tensor.shape).item()) + bytes_per_item = _SIZE[tensor.dtype] + + total_bytes = length * bytes_per_item + + ptr = tensor.data_ptr() + if ptr == 0: + return b"" + newptr = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_ubyte)) + data = np.ctypeslib.as_array(newptr, (total_bytes,)) # no internal copy + if sys.byteorder == "big": + NPDTYPES = { + torch.int64: np.int64, + torch.float32: np.float32, + torch.int32: np.int32, + # XXX: This is ok because both have the same width + torch.bfloat16: np.float16, + torch.float16: np.float16, + torch.int16: np.int16, + torch.uint8: np.uint8, + torch.int8: np.int8, + torch.bool: bool, + torch.float64: np.float64, + # XXX: This is ok because both have the same width and byteswap is a no-op anyway + _float8_e4m3fn: np.uint8, + _float8_e5m2: np.uint8, + } + npdtype = NPDTYPES[tensor.dtype] + # Not in place as that would potentially modify a live running model + data = data.view(npdtype).byteswap(inplace=False) + return data.tobytes() + + +def _flatten(tensors: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, Any]]: + if not isinstance(tensors, dict): + raise ValueError(f"Expected a dict of [str, torch.Tensor] but received {type(tensors)}") + + invalid_tensors = [] + for k, v in tensors.items(): + if not isinstance(v, torch.Tensor): + raise ValueError(f"Key `{k}` is invalid, expected torch.Tensor but received {type(v)}") + + if v.layout != torch.strided: + invalid_tensors.append(k) + if invalid_tensors: + raise ValueError( + f"You are trying to save a sparse tensors: `{invalid_tensors}` which this library does not support." + " You can make it a dense tensor before saving with `.to_dense()` but be aware this might" + " make a much larger file than needed." + ) + + shared_pointers = _find_shared_tensors(tensors) + failing = [] + for names in shared_pointers: + if len(names) > 1: + failing.append(names) + + if failing: + raise RuntimeError( + f""" + Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: {failing}. + A potential way to correctly save your model is to use `save_model`. + More information at https://huggingface.co/docs/safetensors/torch_shared_tensors + """ + ) + + return { + k: { + "dtype": str(v.dtype).split(".")[-1], + "shape": v.shape, + "data": _tobytes(v, k), + } + for k, v in tensors.items() + } diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_sources.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_sources.py new file mode 100644 index 0000000000000000000000000000000000000000..e0ab883a8b46ced06b57bd4dc809861ae4c77af4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_sources.py @@ -0,0 +1,138 @@ +# mypy: allow-untyped-defs +import ast +import functools +import inspect +from textwrap import dedent +from typing import Any, NamedTuple + +from torch._C import ErrorReport +from torch._C._jit_tree_views import SourceRangeFactory + + +def get_source_lines_and_file( + obj: Any, + error_msg: str | None = None, +) -> tuple[list[str], int, str | None]: + """ + Wrapper around inspect.getsourcelines and inspect.getsourcefile. 
+ + Returns: (sourcelines, file_lino, filename) + """ + filename = None # in case getsourcefile throws + try: + filename = inspect.getsourcefile(obj) + sourcelines, file_lineno = inspect.getsourcelines(obj) + except OSError as e: + msg = ( + f"Can't get source for {obj}. TorchScript requires source access in " + "order to carry out compilation, make sure original .py files are " + "available." + ) + if error_msg: + msg += "\n" + error_msg + raise OSError(msg) from e + + return sourcelines, file_lineno, filename + + +def normalize_source_lines(sourcelines: list[str]) -> list[str]: + """ + This helper function accepts a list of source lines. It finds the + indentation level of the function definition (`def`), then it indents + all lines in the function body to a point at or greater than that + level. This allows for comments and continued string literals that + are at a lower indentation than the rest of the code. + Args: + sourcelines: function source code, separated into lines by + the '\n' character + Returns: + A list of source lines that have been correctly aligned + """ + + def remove_prefix(text, prefix): + return text[text.startswith(prefix) and len(prefix) :] + + # Find the line and line number containing the function definition + idx = None + for i, l in enumerate(sourcelines): + if l.lstrip().startswith("def"): + idx = i + break + + # This will happen when the function is a lambda- we won't find "def" anywhere in the source + # lines in that case. Currently trying to JIT compile a lambda will throw an error up in + # `parse_def()`, but we might want to handle this case in the future. + if idx is None: + return sourcelines + + # Get a string representing the amount of leading whitespace + fn_def = sourcelines[idx] + whitespace = fn_def.split("def")[0] + + # Add this leading whitespace to all lines before and after the `def` + aligned_prefix = [ + whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx] + ] + aligned_suffix = [ + whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1 :] + ] + + # Put it together again + aligned_prefix.append(fn_def) + return aligned_prefix + aligned_suffix + + +# Thin wrapper around SourceRangeFactory to store extra metadata +# about the function-to-be-compiled. 
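Before the `SourceContext` wrapper defined next, a small sketch of what `normalize_source_lines` does. `torch._sources` is a private module, so this is illustrative only.

```python
from torch._sources import normalize_source_lines

# A comment dedented below an indented function body defeats textwrap.dedent
# and makes ast.parse fail; normalize_source_lines re-indents such lines to
# the indentation level of the `def`.
lines = [
    "    def f(x):\n",
    "# stray comment at column 0\n",
    "        return x + 1\n",
]
for line in normalize_source_lines(lines):
    print(repr(line))
# '    def f(x):\n'
# '    # stray comment at column 0\n'
# '        return x + 1\n'
```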
+class SourceContext(SourceRangeFactory): + def __init__( + self, + source, + filename, + file_lineno, + leading_whitespace_len, + uses_true_division=True, + funcname=None, + ): + super().__init__(source, filename, file_lineno, leading_whitespace_len) + self.uses_true_division = uses_true_division + self.filename = filename + self.funcname = funcname + + +@functools.cache +def make_source_context(*args): + return SourceContext(*args) + + +def fake_range(): + return SourceContext("", None, 0, 0).make_raw_range(0, 1) + + +class ParsedDef(NamedTuple): + ast: ast.Module + ctx: SourceContext + source: str + filename: str | None + file_lineno: int + + +def parse_def(fn): + sourcelines, file_lineno, filename = get_source_lines_and_file( + fn, ErrorReport.call_stack() + ) + sourcelines = normalize_source_lines(sourcelines) + source = "".join(sourcelines) + dedent_src = dedent(source) + py_ast = ast.parse(dedent_src) + if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef): + raise RuntimeError( + f"Expected a single top-level function: {filename}:{file_lineno}" + ) + leading_whitespace_len = len(source.split("\n", 1)[0]) - len( + dedent_src.split("\n", 1)[0] + ) + ctx = make_source_context( + source, filename, file_lineno, leading_whitespace_len, True, fn.__name__ + ) + return ParsedDef(py_ast, ctx, source, filename, file_lineno) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/library.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/library.py new file mode 100644 index 0000000000000000000000000000000000000000..5305d647bc6136888b0bb476e7d4026579428a2e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/library.py @@ -0,0 +1,1736 @@ +# mypy: allow-untyped-defs +import contextlib +import functools +import inspect +import re +import sys +import traceback +import weakref +from collections.abc import Callable, Sequence +from typing import Any, overload, TYPE_CHECKING, TypeVar, Union +from typing_extensions import deprecated, ParamSpec + +import torch +import torch._library as _library +from torch._library.custom_ops import ( + _cast, + _maybe_get_opdef, + custom_op, + CustomOpDef, + device_types_t, +) +from torch._library.effects import EffectType +from torch._library.infer_schema import infer_schema # noqa: F401 +from torch._library.triton import triton_op, wrap_triton +from torch._ops import OpOverload +from torch.types import _dtype + + +__all__ = [ + "Library", + "impl", + "define", + "fallthrough_kernel", + "impl_abstract", + "register_autocast", + "register_fake", + "register_torch_dispatch", + "register_vmap", + "get_ctx", + "get_kernel", + "custom_op", + "triton_op", + "wrap_triton", + "infer_schema", +] + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered +# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`. +# This set is maintained to ensure that two libraries don't try to override the exact same functionality to avoid +# libraries calling into kernels not intended to be called. +_impls: set[str] = set() +_defs: set[str] = set() + +# prim is reserved by TorchScript interpreter +_reserved_namespaces = ["prim"] + + +def fallthrough_kernel(): + """ + A dummy function to pass to ``Library.impl`` in order to register a fallthrough. 
+ """ + raise NotImplementedError("fallthrough_kernel() should never be called.") + + +class Library: + """ + A class to create libraries that can be used to register new operators or + override operators in existing libraries from Python. + A user can optionally pass in a dispatch keyname if they only want to register + kernels corresponding to only one specific dispatch key. + + To create a library to override operators in an existing library (with name ns), set the kind to "IMPL". + To create a new library (with name ns) to register new operators, set the kind to "DEF". + To create a fragment of a possibly existing library to register operators (and bypass + the limitation that there is only one library for a given namespace), set the kind to + "FRAGMENT". + + Args: + ns: library name + kind: "DEF", "IMPL", "FRAGMENT" + dispatch_key: PyTorch dispatch key (default: "") + """ + + def __init__(self, ns, kind, dispatch_key=""): + from torch.fx.operator_schemas import _SCHEMA_TO_SIGNATURE_CACHE + + if kind not in ("IMPL", "DEF", "FRAGMENT"): + raise ValueError("Unsupported kind: ", kind) + + if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"): + raise ValueError( + ns, + " is a reserved namespace. Please try creating a library with another name.", + ) + + frame = traceback.extract_stack(limit=2)[0] + filename, lineno = frame.filename, frame.lineno + self.m: Any | None = torch._C._dispatch_library( + kind, ns, dispatch_key, filename, lineno + ) + self.ns = ns + self._op_defs: set[str] = set() + self._op_impls: set[str] = set() + self._registration_handles: list[torch._library.utils.RegistrationHandle] = [] + self.kind = kind + self.dispatch_key = dispatch_key + # Use a finalizer to setup the "destructor" instead of __del__. + # Python __del__ can lead to weird things (globals and locals may already + # be gone when __del__ actually gets called!). finalizers help the + # situation because it lets us capture references and keeps them alive + weakref.finalize( + self, + _del_library, + _impls, + self._op_impls, + _defs, + self._op_defs, + self._registration_handles, + self.m, + _SCHEMA_TO_SIGNATURE_CACHE, + ) + + def __repr__(self): + return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>" + + def define(self, schema, alias_analysis="", *, tags=()): + r"""Defines a new operator and its semantics in the ns namespace. + + Args: + schema: function schema to define a new operator. + alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be + inferred from the schema (default behavior) or not ("CONSERVATIVE"). + tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this + operator. Tagging an operator changes the operator's behavior + under various PyTorch subsystems; please read the docs for the + torch.Tag carefully before applying it. + + Returns: + name of the operator as inferred from the schema. + + Example:: + + >>> my_lib = Library("mylib", "DEF") + >>> my_lib.define("sum(Tensor self) -> Tensor") + """ + + # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid + # AliasAnalysis type in C++ + if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]: + raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}") + assert self.m is not None + if isinstance(tags, torch.Tag): + tags = (tags,) + + name = schema.split("(")[0] + packet_name = name.split(".")[0] if "." 
in name else name + has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr( + getattr(torch.ops, self.ns), packet_name + ) + + result = self.m.define(schema, alias_analysis, tuple(tags)) + name = schema.split("(")[0] + qualname = self.ns + "::" + name + + # If the OpOverloadPacket exists already, then this means we're adding a + # new OpOverload for it. Refresh the packet to include the new OpOverload. + if has_preexisting_packet: + ns = getattr(torch.ops, self.ns) + packet = getattr(ns, packet_name) + torch._ops._refresh_packet(packet) + + self._op_defs.add(qualname) + _defs.add(qualname) + return result + + def _register_fake(self, op_name, fn, _stacklevel=1, *, allow_override=False): + r"""Registers the fake impl for an operator defined in the library.""" + + source = torch._library.utils.get_source(_stacklevel + 1) + frame = sys._getframe(_stacklevel) + caller_module = inspect.getmodule(frame) + # Can be none if you call register_fake from somewhere there isn't a module + # (e.g. __main__) + caller_module_name = None if caller_module is None else caller_module.__name__ + + # TODO(rzou): We're gonna need to stage this change with torchvision, + # since torchvision is github first. + if caller_module_name is not None and caller_module_name.startswith( + "torchvision." + ): + caller_module_name = None + + qualname = f"{self.ns}::{op_name}" + entry = torch._library.simple_registry.singleton.find(qualname) + if caller_module_name is not None: + func_to_register = _check_pystubs_once(fn, qualname, caller_module_name) + else: + func_to_register = fn + + handle = entry.fake_impl.register( + func_to_register, source, lib=self, allow_override=allow_override + ) + self._registration_handles.append(handle) + + def _register_torch_dispatch_rule(self, op_name, torch_dispatch_class, fn): + r"""Registers a torch_dispatch rule for the given operator and torch_dispatch_class. + + This allows for open registration to specify the behavior between the operator + and the torch_dispatch_class without needing to modify the torch_dispatch_class + or the operator directly. + + The torch_dispatch_class is either a Tensor subclass with `__torch_dispatch__` or a + TorchDispatchMode. + + If it is a Tensor subclass, we expect fn to have the following signature: + (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any + + If it is a TorchDispatchMode, we expect fn to have the following signature: + (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any + """ + + qualname = f"{self.ns}::{op_name}" + entry = torch._library.simple_registry.singleton.find(qualname) + handle = entry.torch_dispatch_rules.register(torch_dispatch_class, fn) + self._registration_handles.append(handle) + + def _impl_with_aoti_compile(self, op_name, dispatch_key=""): + r"""Register the operator to use the AOTI-compiled implementation. + + Args: + op_name: operator name (along with the overload) or OpOverload object. + dispatch_key: dispatch key that the input function should be registered for. By default, it uses + the dispatch key that the library was created with. 
+ + Example:: + + >>> my_lib = Library("aten", "IMPL") + >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU") + """ + + if dispatch_key == "": + dispatch_key = self.dispatch_key + # pyrefly: ignore [bad-argument-type] + assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense) + + if isinstance(op_name, str): + name = op_name + elif isinstance(op_name, OpOverload): + name = op_name._schema.name + overload_name = op_name._schema.overload_name + if overload_name != "": + name = name + "." + overload_name + else: + raise RuntimeError( + "_impl_with_aoti_compile should be passed either a name or an OpOverload object " + "as the first argument" + ) + + key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key + if key in _impls: + # TODO: in future, add more info about where the existing function is registered (this info is + # today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that) + raise RuntimeError( + "This is not allowed since there's already a kernel registered from python overriding {}" + "'s behavior for {} dispatch key and {} namespace.".format( + name.split("::")[-1], dispatch_key, self.ns + ) + ) + + assert self.m is not None + impl_fn: Callable = self.m.impl_with_aoti_compile + impl_fn(self.ns, name.split("::")[-1], dispatch_key) + + _impls.add(key) + self._op_impls.add(key) + + def impl( + self, op_name, fn, dispatch_key="", *, with_keyset=False, allow_override=False + ): + r"""Registers the function implementation for an operator defined in the library. + + Args: + op_name: operator name (along with the overload) or OpOverload object. + fn: function that's the operator implementation for the input dispatch key or :func:`~fallthrough_kernel` + to register a fallthrough. + dispatch_key: dispatch key that the input function should be registered for. By default, it uses + the dispatch key that the library was created with. + with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument + to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls. + allow_override: Flag controlling if we want to override an + existing registered kernel implementation. This is by + default off, and will error you're trying to register a + kernel to a dispatch key with a kernel already + registered. + + Example:: + + >>> my_lib = Library("aten", "IMPL") + >>> def div_cpu(self, other): + >>> return self * (1 / other) + >>> my_lib.impl("div.Tensor", div_cpu, "CPU") + """ + + if not callable(fn): + raise TypeError( + f"Input function is required to be a callable but found type {type(fn)}" + ) + if dispatch_key == "": + dispatch_key = self.dispatch_key + + if isinstance(op_name, str): + name = op_name + elif isinstance(op_name, OpOverload): + name = op_name._schema.name + overload_name = op_name._schema.overload_name + if overload_name != "": + name = name + "." 
+ overload_name + else: + raise RuntimeError( + "impl should be passed either a name or an OpOverload object as the first argument" + ) + + key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key + if (not allow_override) and key in _impls: + # TODO: in future, add more info about where the existing function is registered (this info is + # today already returned by the C++ warning when impl is called but we error out before that) + raise RuntimeError( + "This is not allowed since there's already a kernel registered from python overriding {}" + "'s behavior for {} dispatch key and {} namespace.".format( + name.split("::")[-1], dispatch_key, self.ns + ) + ) + + if dispatch_key == "Meta": + dispatcher_op_name = name + if "::" not in dispatcher_op_name: + dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}" + + # Internally, we shouldn't be registering meta kernels for any operators that + # have CompositeImplicitAutograd kernels. + # Instead, we should be letting those decompositions run, and writing meta kernels + # only for the base operators. + if torch._C._dispatch_has_kernel_for_dispatch_key( + dispatcher_op_name, "CompositeImplicitAutograd" + ): + raise RuntimeError( + f"We should not register a meta kernel directly to the operator '{name}'," + " because it has a CompositeImplicitAutograd kernel in core." + " Instead we should let the operator decompose, and ensure that we have meta kernels" + " for the base ops that it decomposes into." + ) + + assert self.m is not None + self.m.impl( + name, + dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", + fn, + with_keyset, + ) + + _impls.add(key) + self._op_impls.add(key) + + def fallback(self, fn, dispatch_key="", *, with_keyset=False): + r"""Registers the function implementation as the fallback for the given key. + + This function only works for a library with global namespace ("_"). + + Args: + fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel` + to register a fallthrough. + dispatch_key: dispatch key that the input function should be registered for. By default, it uses + the dispatch key that the library was created with. + with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument + to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls. + + Example:: + + >>> my_lib = Library("_", "IMPL") + >>> def fallback_kernel(op, *args, **kwargs): + >>> # Handle all autocast ops generically + >>> # ... + >>> my_lib.fallback(fallback_kernel, "Autocast") + """ + + if dispatch_key == "": + dispatch_key = self.dispatch_key + + if self.ns != "_": + raise RuntimeError( + f"""Fallback can only be registered using library fragment on the global namespace "_" but it is {self.ns}""" + ) + + assert dispatch_key != "" + assert self.m is not None + + self.m.fallback(dispatch_key, fn, with_keyset) + + def _register_effectful_op(self, op_name: str, effect: EffectType | None): + """ + Registers an effect to an operator. This is used to register an op that + has side effects that is not capturable by the schema. + + Args: + op_name: operator name (along with the overload) or OpOverload object. + effect: The effect of the op. 
+ """ + from torch._higher_order_ops.effects import ( + _register_effectful_op as hoo_register_effect, + ) + + handle = hoo_register_effect(op_name, effect) + self._registration_handles.append(handle) + + def _destroy(self): + if self.m is not None: + self.m.reset() + self.m = None + for handle in self._registration_handles: + handle.destroy() + self._registration_handles.clear() + global _impls + _impls -= self._op_impls + for name in self._op_defs: + # Delete the cached torch.ops.ns.foo if it was registered. + # Otherwise, accessing it leads to a segfault. + # It's possible that we only registered an overload in this Library + # and another library owns an alive overload. + # That's OK - the next time torch.ops.ns.foo gets called, it'll be + # recomputed to point at the right collection of overloads. + ns, name_with_overload = name.split("::") + name = name_with_overload.split(".")[0] + if not hasattr(torch.ops, ns): + continue + namespace = getattr(torch.ops, ns) + if not hasattr(namespace, name): + continue + delattr(namespace, name) + namespace._dir.remove(name) + + +def _del_library( + captured_impls, + op_impls, + captured_defs, + op_defs, + registration_handles, + m, + schema_to_signature_cache, +): + for op_def in op_defs: + name = op_def + overload_name = "" + if "." in op_def: + name, overload_name = op_def.split(".") + if ( + name, + overload_name, + ) in schema_to_signature_cache: + del schema_to_signature_cache[(name, overload_name)] + + captured_impls -= op_impls + captured_defs -= op_defs + for handle in registration_handles: + handle.destroy() + + if m is not None: + m.reset() + + +@contextlib.contextmanager +def _scoped_library(*args, **kwargs): + try: + lib = Library(*args, **kwargs) + yield lib + finally: + lib._destroy() + + +_keep_alive: list[Library] = [] + + +NAMELESS_SCHEMA = re.compile(r"\(.*\) -> .*") + + +@functools.singledispatch +def define(qualname, schema, *, lib=None, tags=()): + r"""Defines a new operator. + + In PyTorch, defining an op (short for "operator") is a two step-process: + - we need to define the op (by providing an operator name and schema) + - we need to implement behavior for how the operator interacts with + various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc. + + This entrypoint defines the custom operator (the first step) + you must then perform the second step by calling various + ``impl_*`` APIs, like :func:`torch.library.impl` or + :func:`torch.library.register_fake`. + + Args: + qualname (str): The qualified name for the operator. Should be + a string that looks like "namespace::name", e.g. "aten::sin". + Operators in PyTorch need a namespace to + avoid name collisions; a given operator may only be created once. + If you are writing a Python library, we recommend the namespace to + be the name of your top-level module. + schema (str): The schema of the operator. E.g. "(Tensor x) -> Tensor" + for an op that accepts one Tensor and returns one Tensor. It does + not contain the operator name (that is passed in ``qualname``). + lib (Optional[Library]): If provided, the lifetime of this operator + will be tied to the lifetime of the Library object. + tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this + operator. Tagging an operator changes the operator's behavior + under various PyTorch subsystems; please read the docs for the + torch.Tag carefully before applying it. 
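`_scoped_library` above has no example of its own; here is a sketch of using it so a throwaway registration does not leak between runs. The namespace and operator are made up for illustration, and `_scoped_library` is a private helper.

```python
import torch
from torch.library import _scoped_library  # private helper defined above

# Register a temporary operator whose definition and kernel are torn down
# when the with-block exits, so repeated test runs do not collide.
with _scoped_library("mytestlib", "FRAGMENT") as lib:
    lib.define("add_one(Tensor x) -> Tensor")
    lib.impl("add_one", lambda x: x + 1, "CompositeExplicitAutograd")
    y = torch.ops.mytestlib.add_one(torch.zeros(3))
    assert torch.equal(y, torch.ones(3))
# the Library has been destroyed here; torch.ops.mytestlib.add_one is gone
```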
+ + Example:: + >>> import torch + >>> import numpy as np + >>> + >>> # Define the operator + >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor") + >>> + >>> # Add implementations for the operator + >>> @torch.library.impl("mylib::sin", "cpu") + >>> def f(x): + >>> return torch.from_numpy(np.sin(x.numpy())) + >>> + >>> # Call the new operator from torch.ops. + >>> x = torch.randn(3) + >>> y = torch.ops.mylib.sin(x) + >>> assert torch.allclose(y, x.sin()) + + """ + if not isinstance(qualname, str): + raise ValueError( + f"define(qualname, schema): expected qualname " + f"to be instance of str, got {type(qualname)}" + ) + namespace, name = torch._library.utils.parse_namespace(qualname) + if lib is None: + lib = Library(namespace, "FRAGMENT") + _keep_alive.append(lib) + if not NAMELESS_SCHEMA.fullmatch(schema): + raise ValueError( + f"define(qualname, schema, ...): expected schema " + f'to look like e.g. "(Tensor x) -> Tensor" but ' + f'got "{schema}"' + ) + lib.define(name + schema, alias_analysis="", tags=tags) + + +@define.register +def _(lib: Library, schema, alias_analysis=""): + """The old torch.library.define. + We're keeping this around for BC reasons + """ + + def wrap(f): + name = lib.define(schema, alias_analysis) + lib.impl(name, f) + return f + + return wrap + + +@overload +def impl( + qualname: str, + types: str | Sequence[str], + func: None = None, + *, + lib: Library | None = None, +) -> Callable[[Callable[..., object]], None]: ... + + +@overload +def impl( + qualname: str, + types: str | Sequence[str], + func: Callable[..., object], + *, + lib: Library | None = None, +) -> None: ... + + +# Deprecated BC API +@overload +def impl( + lib: Library, + name: str, + dispatch_key: str = "", +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... + + +@functools.singledispatch +def impl( + qualname: str, + types: str | Sequence[str], + func: Callable[_P, _T] | None = None, + *, + lib: Library | None = None, +) -> object: + """Register an implementation for a device type for this operator. + + You may pass "default" for ``types`` to register this implementation as the + default implementation for ALL device types. + Please only use this if the implementation truly supports all device types; + for example, this is true if it is a composition of built-in PyTorch operators. + + This API may be used as a decorator. You can use nested decorators + with this API provided they return a function and are placed inside + this API (see Example 2). + + Some valid types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu". + + Args: + qualname (str): Should be a string that looks like "namespace::operator_name". + types (str | Sequence[str]): The device types to register an impl to. + lib (Optional[Library]): If provided, the lifetime of this registration + will be tied to the lifetime of the Library object. + + Examples: + >>> import torch + >>> import numpy as np + >>> # Example 1: Register function. + >>> # Define the operator + >>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor") + >>> + >>> # Add implementations for the cpu device + >>> @torch.library.impl("mylib::mysin", "cpu") + >>> def f(x): + >>> return torch.from_numpy(np.sin(x.numpy())) + >>> + >>> x = torch.randn(3) + >>> y = torch.ops.mylib.mysin(x) + >>> assert torch.allclose(y, x.sin()) + >>> + >>> # Example 2: Register function with decorator. 
+ >>> def custom_decorator(func): + >>> def wrapper(*args, **kwargs): + >>> return func(*args, **kwargs) + 1 + >>> return wrapper + >>> + >>> # Define the operator + >>> torch.library.define("mylib::sin_plus_one", "(Tensor x) -> Tensor") + >>> + >>> # Add implementations for the operator + >>> @torch.library.impl("mylib::sin_plus_one", "cpu") + >>> @custom_decorator + >>> def f(x): + >>> return torch.from_numpy(np.sin(x.numpy())) + >>> + >>> # Call the new operator from torch.ops. + >>> x = torch.randn(3) + >>> + >>> y1 = torch.ops.mylib.sin_plus_one(x) + >>> y2 = torch.sin(x) + 1 + >>> assert torch.allclose(y1, y2) + """ + + return _impl(qualname, types, func, lib=lib, disable_dynamo=False) + + +if not TYPE_CHECKING: + + @impl.register + def _( + lib: Library, name: str, dispatch_key: str = "" + ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """Legacy torch.library.impl API. Kept around for BC""" + + def wrap(f: Callable[_P, _T]) -> Callable[_P, _T]: + lib.impl(name, f, dispatch_key) + return f + + return wrap + + +@overload +def _impl( + qualname: str, + types: str | Sequence[str], + func: None = None, + *, + lib: Library | None = None, + disable_dynamo: bool = False, +) -> Callable[[Callable[..., object]], None]: ... + + +@overload +def _impl( + qualname: str, + types: str | Sequence[str], + func: Callable[..., object], + *, + lib: Library | None = None, + disable_dynamo: bool = False, +) -> None: ... + + +def _impl( + qualname: str, + types: str | Sequence[str], + func: Callable[..., object] | None = None, + *, + lib: Library | None = None, + disable_dynamo: bool = False, +) -> Callable[[Callable[..., object]], None] | None: + # See impl() + if isinstance(types, str): + types = (types,) + keys = set({}) + for typ in types: + is_dispatch_key = torch._C._parse_dispatch_key(typ) + if is_dispatch_key: + # We also support passing a DispatchKey to impl. Please prefer using + # the higher-level torch.library APIs and only pass DispatchKey to + # torch.library.impl with caution (or even better, don't use this + # option and file an issue on GitHub for what you need). + # We don't advertise this to users because + # it is very easy to shoot yourself in the foot. + keys.add(typ) + else: + keys.add(_device_type_to_key(typ)) + + def register_(func: Callable[..., object]) -> None: + namespace, _ = torch._library.utils.parse_namespace(qualname) + + if lib is None: + use_lib = Library(namespace, "FRAGMENT") + _keep_alive.append(use_lib) + else: + use_lib = lib + if disable_dynamo: + + @torch._disable_dynamo + def func_no_dynamo(*args, **kwargs): + return func(*args, **kwargs) + + for key in keys: + use_lib.impl(qualname, func_no_dynamo, key) + else: + for key in keys: + use_lib.impl(qualname, func, key) + + if func is None: + return register_ + else: + register_(func) + return None + + +def _device_type_to_key(device_type: str) -> str: + if device_type == "default": + # This is technically not correct, because although all device_type + # DispatchKeys are included in CompositeExplicitAutograd, + # not everything in CompositeExplicitAutograd is associated with a + # device_type. I don't really care that much about the difference. + return "CompositeExplicitAutograd" + return torch._C._dispatch_key_for_device(device_type) + + +@deprecated( + "`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. 
Please use that " + "instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.", + category=FutureWarning, +) +def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1): + r"""This API was renamed to :func:`torch.library.register_fake` in PyTorch 2.4. + Please use that instead. + """ + if func is not None: + _stacklevel = _stacklevel + 1 + return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel) + + +_op_identifier = Union[ + str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef" +] + + +def register_kernel( + op: _op_identifier, + device_types: device_types_t, + func: Callable | None = None, + /, + *, + lib: Library | None = None, +): + """Register an implementation for a device type for this operator. + + Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu". + This API may be used as a decorator. + + Args: + op (str | OpOverload): The operator to register an impl to. + device_types (None | str | Sequence[str]): The device_types to register an impl to. + If None, we will register to all device types -- please only use + this option if your implementation is truly device-type-agnostic. + func (Callable): The function to register as the implementation for + the given device types. + lib (Optional[Library]): If provided, the lifetime of this registration + + Examples:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> from torch import Tensor + >>> from torch.library import custom_op + >>> import numpy as np + >>> + >>> # Create a custom op that works on cpu + >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu") + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np) + >>> + >>> # Add implementations for the cuda device + >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda") + >>> def _(x): + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> x_cpu = torch.randn(3) + >>> x_cuda = x_cpu.cuda() + >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin()) + >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin()) + + """ + + if not isinstance( + op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) + ): + raise ValueError( + f"register_kernel({op}): got unexpected type for op: {type(op)}" + ) + if isinstance(op, torch._ops.OpOverload): + op = op._name + opdef = _maybe_get_opdef(op) + if opdef is not None: + return opdef.register_kernel(device_types, func) + assert isinstance(op, str) + if device_types is None: + device_types = "CompositeExplicitAutograd" + + return _impl(op, device_types, func, lib=lib, disable_dynamo=True) + + +def register_autocast( + op: _op_identifier, + device_type: str, + cast_inputs: _dtype, + /, + *, + lib: Library | None = None, +): + r"""Register an autocast dispatch rule for this custom op. + + Valid `device_type` include: "cpu" and "cuda". + + Args: + op (str | OpOverload): The operator to register an autocast dispatch rule to. + device_type(str): Device type to use. 'cuda' or 'cpu'. + The type is the same as the `type` attribute of a :class:`torch.device`. + Thus, you may obtain the device type of a tensor using `Tensor.device.type`. 
+ cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region, + casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors + are not affected), then executes custom op with autocast disabled. + lib (Optional[Library]): If provided, the lifetime of this registration + + Examples:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> from torch import Tensor + >>> from torch.library import custom_op + >>> + >>> # Create a custom op that works on cuda + >>> @torch.library.custom_op("mylib::my_sin", mutates_args=()) + >>> def my_sin(x: Tensor) -> Tensor: + >>> return torch.sin(x) + >>> + >>> # Register autocast dispatch rule for the cuda device + >>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16) + >>> + >>> x = torch.randn(3, dtype=torch.float32, device="cuda") + >>> with torch.autocast("cuda", dtype=torch.float16): + >>> y = torch.ops.mylib.my_sin(x) + >>> assert y.dtype == torch.float16 + + """ + if not isinstance( + op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) + ): + raise ValueError( + f"register_autocast({op}): got unexpected type for op: {type(op)}" + ) + if device_type not in ["cpu", "cuda"]: + raise ValueError(f"Unknown device type: {device_type}") + + if isinstance(op, torch._ops.OpOverload): + op = op._name + opdef = _maybe_get_opdef(op) + if opdef is not None: + return opdef.register_autocast(device_type, cast_inputs) + + assert isinstance(op, str) + qualname = op + _op = torch._library.utils.lookup_op(qualname) + + namespace, opname = torch._library.utils.parse_namespace(qualname) + if lib is None: + lib = Library(namespace, "FRAGMENT") + _keep_alive.append(lib) + + def _maybe_override_py_impl(op: torch._ops.OpOverload, dispatch_key): + def inner(kernel): + if op.has_kernel_for_dispatch_key(dispatch_key): + op.py_kernels.pop(dispatch_key) + return op.py_impl(dispatch_key)(kernel) + + return inner + + @_maybe_override_py_impl(_op, torch._C.DispatchKey.AutocastCPU) + @_maybe_override_py_impl(_op, torch._C.DispatchKey.AutocastCUDA) + def _autocast_py_impl(*args, **kwargs): + assert len(kwargs) == 0, "Custom ops do not support kwargs yet." + autocast_keyset = torch._C.DispatchKeySet( + torch._C.DispatchKey.AutocastCPU + ) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA) + with torch._C._ExcludeDispatchKeyGuard(autocast_keyset): + return _op(*_cast(args, device_type, cast_inputs)) + + def kernel(_, *args, **kwargs): + assert len(kwargs) == 0, "Custom ops do not support kwargs yet." + return _autocast_py_impl(*args, **kwargs) + + if device_type == "cuda": + return lib.impl(opname, kernel, "AutocastCUDA", with_keyset=True) + else: + # device_type is "cpu" + return lib.impl(opname, kernel, "AutocastCPU", with_keyset=True) + + +def register_fake( + op: _op_identifier, + func: Callable | None = None, + /, + *, + lib: Library | None = None, + _stacklevel: int = 1, + allow_override: bool = False, +): + r"""Register a FakeTensor implementation ("fake impl") for this operator. + + Also sometimes known as a "meta kernel", "abstract impl". + + An "FakeTensor implementation" specifies the behavior of this operator on + Tensors that carry no data ("FakeTensor"). Given some input Tensors with + certain properties (sizes/strides/storage_offset/device), it specifies + what the properties of the output Tensors are. + + The FakeTensor implementation has the same signature as the operator. + It is run for both FakeTensors and meta tensors. 
To write a FakeTensor + implementation, assume that all Tensor inputs to the operator are + regular CPU/CUDA/Meta tensors, but they do not have storage, and + you are trying to return regular CPU/CUDA/Meta tensor(s) as output. + The FakeTensor implementation must consist of only PyTorch operations + (and may not directly access the storage or data of any input or + intermediate Tensors). + + This API may be used as a decorator (see examples). + + For a detailed guide on custom ops, please see + https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html + + Args: + op_name: Operator name (along with the overload) or OpOverload object. + func: Fake tensor implementation. + lib (Optional[Library]): Library to register the fake tensor to. + allow_override: Flag controlling if we want to override an + existing registered fake impl. This is by default off, + and will error you're trying to register a fake impl to + an operator that already has a fake impl. This also only + applies if the custom operator was not created via + torch.library.custom_op, as overriding and existing fake + impl is already allowed. + + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Example 1: an operator without data-dependent output shape + >>> @torch.library.custom_op("mylib::custom_linear", mutates_args=()) + >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: + >>> raise NotImplementedError("Implementation goes here") + >>> + >>> @torch.library.register_fake("mylib::custom_linear") + >>> def _(x, weight, bias): + >>> assert x.dim() == 2 + >>> assert weight.dim() == 2 + >>> assert bias.dim() == 1 + >>> assert x.shape[1] == weight.shape[1] + >>> assert weight.shape[0] == bias.shape[0] + >>> assert x.device == weight.device + >>> + >>> return (x @ weight.t()) + bias + >>> + >>> with torch._subclasses.fake_tensor.FakeTensorMode(): + >>> x = torch.randn(2, 3) + >>> w = torch.randn(3, 3) + >>> b = torch.randn(3) + >>> y = torch.ops.mylib.custom_linear(x, w, b) + >>> + >>> assert y.shape == (2, 3) + >>> + >>> # Example 2: an operator with data-dependent output shape + >>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=()) + >>> def custom_nonzero(x: Tensor) -> Tensor: + >>> x_np = x.numpy(force=True) + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> return torch.tensor(res, device=x.device) + >>> + >>> @torch.library.register_fake("mylib::custom_nonzero") + >>> def _(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an fake impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. 
+ >>> ctx = torch.library.get_ctx() + >>> nnz = ctx.new_dynamic_size() + >>> shape = [nnz, x.dim()] + >>> result = x.new_empty(shape, dtype=torch.int64) + >>> return result + >>> + >>> from torch.fx.experimental.proxy_tensor import make_fx + >>> + >>> x = torch.tensor([0, 1, 2, 3, 4, 0]) + >>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x) + >>> trace.print_readable() + >>> + >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x)) + + """ + if not isinstance( + op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) + ): + raise ValueError(f"register_fake({op}): got unexpected type for op: {type(op)}") + if isinstance(op, torch._ops.OpOverload): + op = op._name + opdef = _maybe_get_opdef(op) + if opdef is not None: + if func is None: + return opdef.register_fake + else: + return opdef.register_fake(func) + assert isinstance(op, str) + + stacklevel = _stacklevel + + def register(func): + namespace, op_name = torch._library.utils.parse_namespace(op) + if lib is None: + use_lib = Library(namespace, "FRAGMENT") + _keep_alive.append(use_lib) + else: + use_lib = lib + use_lib._register_fake( + op_name, func, _stacklevel=stacklevel + 1, allow_override=allow_override + ) + return func + + if func is None: + return register + else: + stacklevel += 1 + return register(func) + + +def _register_effectful_op( + op: _op_identifier, + effect: EffectType | None, + *, + lib: Library | None = None, +) -> None: + r""" + To specify that an operator has side-effects, we must register an effect + type for the operator. This will prevent graph passes in torch.compile from + reordering operations with the same effect type. + + Args: + op_name: Operator name (along with the overload) or OpOverload object. + effect: Effect type to register. None means the operator is not effectful. + """ + if not isinstance( + op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) + ): + raise ValueError( + f"register_effectful_op({op}): got unexpected type for op: {type(op)}" + ) + + if isinstance(op, torch._ops.OpOverload): + op = op._name + opdef = _maybe_get_opdef(op) + if opdef is not None: + opdef.register_effect(effect) + assert isinstance(op, str) + + namespace, _ = torch._library.utils.parse_namespace(op) + if lib is None: + use_lib = Library(namespace, "FRAGMENT") + _keep_alive.append(use_lib) + else: + use_lib = lib + use_lib._register_effectful_op(op, effect) + + +def register_autograd( + op: _op_identifier, + backward: Callable, + /, + *, + setup_context: Callable | None = None, + lib=None, +) -> None: + r"""Register a backward formula for this custom op. + + In order for an operator to work with autograd, you need to register + a backward formula: + 1. You must tell us how to compute gradients during the backward pass + by providing us a "backward" function. + 2. If you need any values from the forward to compute gradients, you can + use `setup_context` to save values for backward. + + ``backward`` runs during the backward pass. It accepts ``(ctx, *grads)``: + - ``grads`` is one or more gradients. The number of gradients matches + the number of outputs of the operator. + The ``ctx`` object is `the same ctx object `_ used by + :class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the + same as :meth:`torch.autograd.Function.backward`. + + ``setup_context(ctx, inputs, output)`` runs during the forward pass. 
+ Please save quantities needed for backward onto the ``ctx`` object via + either :meth:`torch.autograd.function.FunctionCtx.save_for_backward` + or assigning them as attributes of ``ctx``. If your custom op has + kwarg-only arguments, we expect the signature of ``setup_context`` + to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``. + + Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is, + they may not directly access :meth:`torch.Tensor.data_ptr` and they must + not depend on or mutate global state. If you need a non-traceable backward, + you can make it a separate custom_op that you call inside ``backward_fn``. + + If you need different autograd behavior on different devices, then we + recommend creating two different custom operators, one for each device + that needs different behavior, and switching between them at runtime. + + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> def setup_context(ctx, inputs, output) -> Tensor: + >>> x, = inputs + >>> ctx.save_for_backward(x) + >>> + >>> def backward(ctx, grad): + >>> x, = ctx.saved_tensors + >>> return grad * x.cos() + >>> + >>> torch.library.register_autograd( + ... "mylib::numpy_sin", backward, setup_context=setup_context + ... ) + >>> + >>> x = torch.randn(3, requires_grad=True) + >>> y = numpy_sin(x) + >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) + >>> assert torch.allclose(grad_x, x.cos()) + >>> + >>> # Example with a keyword-only arg + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = x_np * val + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor: + >>> ctx.val = keyword_only_inputs["val"] + >>> + >>> def backward(ctx, grad): + >>> return grad * ctx.val + >>> + >>> torch.library.register_autograd( + ... "mylib::numpy_mul", backward, setup_context=setup_context + ... ) + >>> + >>> x = torch.randn(3, requires_grad=True) + >>> y = numpy_mul(x, val=3.14) + >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) + >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14)) + + """ + if not isinstance( + op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) + ): + raise ValueError( + f"register_autograd({op}): got unexpected type for op: {type(op)}" + ) + if isinstance(op, torch._ops.OpOverload): + op = op._name + opdef = _maybe_get_opdef(op) + if opdef is not None: + opdef.register_autograd(backward, setup_context=setup_context) + return + + assert isinstance(op, str) + qualname = op + op = torch._library.utils.lookup_op(qualname) + schema = op._schema + if not _library.utils.is_functional_schema(schema): + raise RuntimeError( + f"Cannot register autograd formula for non-functional operator " + f"{op} with schema {schema}. Please create " + f"a functional operator and register an autograd formula for that." + ) + if _library.utils.has_kwarg_only_tensors(schema): + raise NotImplementedError( + f"register_autograd with kwarg-only Tensor args. In the original " + f"definition of the op, please make your tensors not kwarg-only. 
" + f"Got: {schema}" + ) + + info = _library.autograd.Info(backward, setup_context) + autograd_kernel = _library.autograd.make_autograd_impl(op, info) + namespace, opname = torch._library.utils.parse_namespace(qualname) + if lib is None: + lib = Library(namespace, "FRAGMENT") + _keep_alive.append(lib) + lib.impl(opname, autograd_kernel, "Autograd", with_keyset=True) + + +def register_torch_dispatch( + op: _op_identifier, + torch_dispatch_class: Any, + func: Callable | None = None, + /, + *, + lib: Library | None = None, +): + r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``. + + This allows for open registration to specify the behavior between the operator + and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class`` + or the operator directly. + + The ``torch_dispatch_class`` is either a Tensor subclass with ``__torch_dispatch__`` or a + TorchDispatchMode. + + If it is a Tensor subclass, we expect ``func`` to have the following signature: + ``(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any`` + + If it is a TorchDispatchMode, we expect ``func`` to have the following signature: + ``(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any`` + + ``args`` and ``kwargs`` will have been normalized the same way they are + in ``__torch_dispatch__`` (see :ref:`torch-dispatch-calling-convention`). + + Examples: + + >>> import torch + >>> + >>> @torch.library.custom_op("mylib::foo", mutates_args={}) + >>> def foo(x: torch.Tensor) -> torch.Tensor: + >>> return x.clone() + >>> + >>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode): + >>> def __torch_dispatch__(self, func, types, args=(), kwargs=None): + >>> return func(*args, **kwargs) + >>> + >>> @torch.library.register_torch_dispatch("mylib::foo", MyMode) + >>> def _(mode, func, types, args, kwargs): + >>> x, = args + >>> return x + 1 + >>> + >>> x = torch.randn(3) + >>> y = foo(x) + >>> assert torch.allclose(y, x) + >>> + >>> with MyMode(): + >>> y = foo(x) + >>> assert torch.allclose(y, x + 1) + + """ + if not isinstance( + op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) + ): + raise ValueError( + f"register_torch_dispatch({op}): got unexpected type for op: {type(op)}" + ) + if isinstance(op, torch._ops.OpOverload): + op = op._name + opdef = _maybe_get_opdef(op) + if opdef is not None: + return opdef.register_torch_dispatch(torch_dispatch_class, func) + assert isinstance(op, str) + + def register(func): + namespace, op_name = torch._library.utils.parse_namespace(op) + if lib is None: + use_lib = Library(namespace, "FRAGMENT") + _keep_alive.append(use_lib) + else: + use_lib = lib + use_lib._register_torch_dispatch_rule(op_name, torch_dispatch_class, func) + return func + + if func is None: + return register + else: + return register(func) + + +def register_vmap( + op: _op_identifier, + func: Callable | None = None, + /, + *, + lib=None, +): + r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op. + + This API may be used as a decorator (see examples). + + In order for an operator to work with :func:`torch.vmap`, you may need to register a + vmap implementation in the following signature: + + ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``, + + where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``. + We do not support kwarg-only Tensor args. 
+ + It specifies how do we compute the batched version of ``op`` given inputs with an additional + dimension (specified by ``in_dims``). + + For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None`` + if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer + specifying what dimension of the Tensor is being vmapped over. + + ``info`` is a collection of additional metadata that may be helpful: + ``info.batch_size`` specifies the size of the dimension being vmapped over, while + ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`. + + The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``, + ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim`` + per output that specifies if the output has the vmapped dimension and what index it is in. + + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> from typing import Tuple + >>> + >>> def to_numpy(tensor): + >>> return tensor.cpu().numpy() + >>> + >>> lib = torch.library.Library("mylib", "FRAGMENT") + >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=()) + >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: + >>> x_np = to_numpy(x) + >>> dx = torch.tensor(3 * x_np ** 2, device=x.device) + >>> return torch.tensor(x_np ** 3, device=x.device), dx + >>> + >>> def numpy_cube_vmap(info, in_dims, x): + >>> result = numpy_cube(x) + >>> return result, (in_dims[0], in_dims[0]) + >>> + >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap) + >>> + >>> x = torch.randn(3) + >>> torch.vmap(numpy_cube)(x) + >>> + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor: + >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) + >>> + >>> @torch.library.register_vmap("mylib::numpy_mul") + >>> def numpy_mul_vmap(info, in_dims, x, y): + >>> x_bdim, y_bdim = in_dims + >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + >>> result = x * y + >>> result = result.movedim(-1, 0) + >>> return result, 0 + >>> + >>> + >>> x = torch.randn(3) + >>> y = torch.randn(3) + >>> torch.vmap(numpy_mul)(x, y) + + .. note:: + The vmap function should aim to preserve the semantics of the entire custom operator. + That is, ``grad(vmap(op))`` should be replaceable with a ``grad(map(op))``. + + If your custom operator has any custom behavior in the backward pass, please + keep this in mind. + + """ + if not isinstance( + op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) + ): + raise ValueError(f"register_vmap({op}): got unexpected type for op: {type(op)}") + if isinstance(op, torch._ops.OpOverload): + op = op._name + opdef = _maybe_get_opdef(op) + if opdef is not None: + return opdef.register_vmap(func) + assert isinstance(op, str) + qualname = op + op = torch._library.utils.lookup_op(qualname) + schema = op._schema + if _library.utils.has_kwarg_only_tensors(schema): + raise NotImplementedError( + f"register_vmap with kwarg-only Tensor args. In the original " + f"definition of the op, please make your tensors not kwarg-only. 
" + f"Got: {schema}" + ) + + def register(func): + nonlocal op, lib + + namespace, opname = torch._library.utils.parse_namespace(qualname) + if lib is None: + lib = Library(namespace, "FRAGMENT") + _keep_alive.append(lib) + + from torch._functorch.autograd_function import custom_function_call_vmap_helper + from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter + + def wrapped_func(keyset, *args, **kwargs): + interpreter = retrieve_current_functorch_interpreter() + return custom_function_call_vmap_helper( + interpreter, func, op, *args, **kwargs + ) + + lib.impl(opname, wrapped_func, "FuncTorchBatched", with_keyset=True) + + if func is None: + return register + else: + return register(func) + + +# If the op was defined in C++, then we want to make sure there was an +# m.set_python_module(module, ...) call and that the module is the +# same as the module that called torch.library.register_fake. +def _check_pystubs_once(func, qualname, actual_module_name): + checked = False + + def inner(*args, **kwargs): + nonlocal checked + if checked: + return func(*args, **kwargs) + + op = torch._library.utils.lookup_op(qualname) + if op._defined_in_python: + checked = True + return func(*args, **kwargs) + + maybe_pystub = torch._C._dispatch_pystub( + op._schema.name, op._schema.overload_name + ) + if maybe_pystub is None: + if torch._library.utils.requires_set_python_module(): + namespace = op.namespace + cpp_filename = op._handle.debug() + raise RuntimeError( + f"Operator '{qualname}' was defined in C++ and has a Python " + f"fake impl. In this situation, we require there to also be a " + f'companion C++ `m.set_python_module("{actual_module_name}")` ' + f"call, but we could not find one. Please add that to " + f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the " + f"operator was registered in ({cpp_filename})" + ) + else: + pystub_module = maybe_pystub[0] + if actual_module_name != pystub_module: + cpp_filename = op._handle.debug() + raise RuntimeError( + f"Operator '{qualname}' specified that its python fake impl " + f"is in the Python module '{pystub_module}' but it was actually found " + f"in '{actual_module_name}'. Please either move the fake impl " + f"or correct the m.set_python_module call ({cpp_filename})" + ) + checked = True + return func(*args, **kwargs) + + return inner + + +# NOTE [ctx inside the fake implementation] +# If a user has an operator with data-dependent output shape, then when writing +# a fake implementation they must query the current ctx and use methods on the +# ctx to construct a new unbacked symint. +# +# This is done via us setting the global_ctx_getter function every time a fake +# implementation is invoked. +def get_ctx() -> "torch._library.fake_impl.FakeImplCtx": + """get_ctx() returns the current AbstractImplCtx object. + + Calling ``get_ctx()`` is only valid inside of an fake impl + (see :func:`torch.library.register_fake` for more usage details. + """ + return torch._library.fake_impl.global_ctx_getter() + + +def get_kernel( + op: _op_identifier, dispatch_key: str | torch.DispatchKey +) -> torch._C._SafeKernelFunction: + """Returns the computed kernel for a given operator and dispatch key. + + This function retrieves the kernel that would be executed for a given + operator and dispatch key combination. The returned SafeKernelFunction + can be used to call the kernel in a boxed fashion. 
The intended use + case for this function is to retrieve the original kernel for a given + dispatch key and then register another kernel to the same dispatch key + that calls into the original kernel for certain cases. + + Args: + op: Operator name (along with the overload) or OpOverload object + Can be a string (e.g., "aten::add.Tensor"), an OpOverload, or a CustomOpDef. + dispatch_key (str | torch.DispatchKey): The dispatch key to get the kernel for. + Can be a string (e.g., "CPU", "CUDA") or a DispatchKey enum value. + + Returns: + torch._C._SafeKernelFunction: A safe kernel function that can be used to + call the kernel. + + Raises: + RuntimeError: If the operator does not exist. + + Example: + >>> # Get the CPU kernel for torch.add + >>> kernel = torch.library.get_kernel("aten::add.Tensor", "CPU") + >>> + >>> # You can also use DispatchKey enum + >>> kernel = torch.library.get_kernel("aten::add.Tensor", torch.DispatchKey.CPU) + >>> + >>> # Or use an OpOverload directly + >>> kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU") + >>> + >>> # Example: Using get_kernel in a custom op with conditional dispatch + >>> # Get the original kernel for torch.sin + >>> original_sin_kernel = torch.library.get_kernel("aten::sin", "CPU") + >>> + >>> # If input has negative values, use original sin, otherwise return zeros + >>> def conditional_sin_impl(dispatch_keys, x): + >>> if (x < 0).any(): + >>> return original_sin_kernel.call_boxed(dispatch_keys, x) + >>> else: + >>> return torch.zeros_like(x) + >>> + >>> lib = torch.library.Library("aten", "IMPL") + >>> # with_keyset=True so the first argument to the impl is the current DispatchKeySet + >>> which needs to be the first argument to ``kernel.call_boxed`` + >>> lib.impl("sin", conditional_sin_impl, "CPU", with_keyset=True) + >>> + >>> # Test the conditional behavior + >>> x_positive = torch.tensor([1.0, 2.0]) + >>> x_mixed = torch.tensor([-1.0, 2.0]) + >>> torch.sin(x_positive) + tensor([0., 0.]) + >>> torch.sin(x_mixed) + tensor([-0.8415, 0.9093]) + """ + if not isinstance(op, (str, torch._ops.OpOverload)): + raise ValueError(f"get_kernel({op}): got unexpected type for op: {type(op)}") + + if isinstance(op, torch._ops.OpOverload): + op = op._name + + if isinstance(dispatch_key, str): + try: + dispatch_key = torch._C.DispatchKey.__members__[dispatch_key] + except KeyError: + raise ValueError(f"Invalid dispatch key: {dispatch_key}") from None + + return torch._C._dispatch_get_computed_kernel_for_dispatch_key(op, dispatch_key) + + +_OPCHECK_DEFAULT_UTILS = ( + "test_schema", + "test_autograd_registration", + "test_faketensor", + "test_aot_dispatch_dynamic", +) + + +def opcheck( + op: torch._ops.OpOverload | torch._ops.OpOverloadPacket | CustomOpDef, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + *, + test_utils: str | Sequence[str] = _OPCHECK_DEFAULT_UTILS, + raise_exception: bool = True, + atol=None, + rtol=None, +) -> dict[str, str]: + """Given an operator and some sample arguments, tests if the operator is + registered correctly. + + That is, when you use the torch.library/TORCH_LIBRARY APIs to create a + custom op, you specified metadata (e.g. mutability info) about the custom op + and these APIs require that the functions you pass them satisfy certain + properties (e.g. no data pointer access in the fake/meta/abstract kernel) + ``opcheck`` tests these metadata and properties. + + Concretely, we test the following: + + - test_schema: If the schema matches the implementation of + the operator. 
For example: if the schema specifies a Tensor is mutated, + then we check the implementation mutates the Tensor. If the schema + specifies that we return a new Tensor, then we check that the + implementation returns a new Tensor (instead of an existing one or + a view of an existing one). + - test_autograd_registration: If the operator supports training + (autograd): we check that its autograd formula is registered via + torch.library.register_autograd or a manual registration to one + or more DispatchKey::Autograd keys. Any other DispatchKey-based + registrations may lead to undefined behavior. + - test_faketensor: If the operator has a FakeTensor kernel + (and if it is correct). The FakeTensor kernel is necessary ( + but not sufficient) for the operator to work with PyTorch compilation + APIs (torch.compile/export/FX). We check that a FakeTensor kernel + (also sometimes known as a meta kernel) was registered for the + operator and that it is correct. This test takes the result of + running the operator on real tensors and the result of running + the operator on FakeTensors and checks that they have the same + Tensor metadata (sizes/strides/dtype/device/etc). + - test_aot_dispatch_dynamic: If the operator has correct behavior + with PyTorch compilation APIs (torch.compile/export/FX). + This checks that the outputs (and gradients, if applicable) are the + same under eager-mode PyTorch and torch.compile. + This test is a superset of ``test_faketensor`` and is an e2e test; + other things it tests are that the operator supports + functionalization and that the backward pass (if it exists) also + supports FakeTensor and functionalization. + + For best results, please call ``opcheck`` multiple times with a + representative set of inputs. If your operator supports + autograd, please use ``opcheck`` with inputs with ``requires_grad = True``; + if your operator supports multiple devices (e.g. CPU and CUDA), please + use ``opcheck`` with inputs on all supported devices. + + Args: + op: The operator. Must either be a function decorated with + :func:`torch.library.custom_op` or an OpOverload/OpOverloadPacket + found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo) + args: The args to the operator + kwargs: The kwargs to the operator + test_utils: Tests that we should run. Default: all of them. + Example: ("test_schema", "test_faketensor") + raise_exception: If we should raise an exception on the first + error. If False, we will return a dict with information + on if each test passed or not. + rtol (Optional[float]): Relative tolerance for floating point comparisons. + If specified ``atol`` must also be specified. + If omitted, default values based on the ``dtype`` are selected + (see the table in :func:`torch.testing.assert_close`). + atol (Optional[float]): Absolute tolerance for floating point comparisons. + If specified ``rtol`` must also be specified. + If omitted, default values based on the ``dtype`` are selected + (see the table in :func:`torch.testing.assert_close`). + + .. warning:: + + opcheck and :func:`torch.autograd.gradcheck` test different things; + opcheck tests if your usage of torch.library APIs is correct while + :func:`torch.autograd.gradcheck` tests if your autograd formula is + mathematically correct. Use both to test custom ops that support + gradient computation. 
+ + Example: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, y: float) -> Tensor: + >>> x_np = x.numpy(force=True) + >>> z_np = x_np * y + >>> return torch.from_numpy(z_np).to(x.device) + >>> + >>> @numpy_mul.register_fake + >>> def _(x, y): + >>> return torch.empty_like(x) + >>> + >>> def setup_context(ctx, inputs, output): + >>> y, = inputs + >>> ctx.y = y + >>> + >>> def backward(ctx, grad): + >>> return grad * ctx.y, None + >>> + >>> numpy_mul.register_autograd(backward, setup_context=setup_context) + >>> + >>> sample_inputs = [ + >>> (torch.randn(3), 3.14), + >>> (torch.randn(2, 3, device='cuda'), 2.718), + >>> (torch.randn(1, 10, requires_grad=True), 1.234), + >>> (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18), + >>> ] + >>> + >>> for args in sample_inputs: + >>> torch.library.opcheck(numpy_mul, args) + + """ + import torch.testing._internal.optests as optests + + return optests.opcheck( + op, + args, + kwargs, + test_utils=test_utils, + raise_exception=raise_exception, + rtol=rtol, + atol=atol, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/return_types.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/return_types.py new file mode 100644 index 0000000000000000000000000000000000000000..d456742be4b88ebdca9f3696a415014a500cdd33 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/return_types.py @@ -0,0 +1,51 @@ +import inspect + +import torch +from torch.utils._pytree import register_pytree_node, SequenceKey + + +__all__ = ["pytree_register_structseq", "all_return_types"] + +all_return_types = [] + +# error: Module has no attribute "_return_types" +return_types = torch._C._return_types # type: ignore[attr-defined] + + +def pytree_register_structseq(cls): + def structseq_flatten(structseq): + return list(structseq), None + + def structseq_flatten_with_keys(structseq): + values, context = structseq_flatten(structseq) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + def structseq_unflatten(values, context): + return cls(values) + + register_pytree_node( + cls, + structseq_flatten, + structseq_unflatten, + flatten_with_keys_fn=structseq_flatten_with_keys, + ) + + +for name in dir(return_types): + if name.startswith("__"): + continue + + _attr = getattr(return_types, name) + globals()[name] = _attr + + if not name.startswith("_"): + __all__.append(name) + all_return_types.append(_attr) + + # Today everything in torch.return_types is a structseq, aka a "namedtuple"-like + # thing defined by the Python C-API. We're going to need to modify this when that + # is no longer the case. 
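+    # (Illustrative note) Typical classes registered here are structseqs such as
+    # torch.return_types.max and torch.return_types.topk, i.e. the (values, indices)
+    # pairs returned by torch.max(x, dim=...) and torch.topk(x, k).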
+ # NB: I don't know how to check that something is a "structseq" so we do a fuzzy + # check for tuple + if inspect.isclass(_attr) and issubclass(_attr, tuple): + pytree_register_structseq(_attr) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/serialization.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..1a6acc8010634ec9f2fcfc1d24f34f2dbe31b8c9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/serialization.py @@ -0,0 +1,2154 @@ +# mypy: allow-untyped-defs +import copyreg +import difflib +import functools +import io +import os +import pickle +import re +import shutil +import struct +import sys +import tarfile +import tempfile +import threading +import warnings +from collections.abc import Callable +from contextlib import closing, contextmanager +from enum import Enum +from typing import Any, cast, Generic, IO, TypeAlias, TypeVar +from typing_extensions import TypeIs + +import torch +import torch._weights_only_unpickler as _weights_only_unpickler +from torch._sources import get_source_lines_and_file +from torch._utils import _import_dotted_name +from torch.storage import _get_dtype_from_pickle_storage_type +from torch.types import FileLike, Storage + + +__all__ = [ + "SourceChangeWarning", + "mkdtemp", + "register_package", + "check_module_version_greater_or_equal", + "validate_cuda_device", + "validate_hpu_device", + "location_tag", + "default_restore_location", + "normalize_storage_type", + "storage_to_tensor_type", + "save", + "load", + "StorageType", + "LoadEndianness", + "get_crc32_options", + "set_crc32_options", + "get_default_load_endianness", + "set_default_load_endianness", + "get_default_mmap_options", + "set_default_mmap_options", + "clear_safe_globals", + "get_safe_globals", + "add_safe_globals", + "safe_globals", + "get_unsafe_globals_in_checkpoint", + "skip_data", +] + +DEFAULT_PROTOCOL = 2 + +LONG_SIZE = struct.Struct("=l").size +INT_SIZE = struct.Struct("=i").size +SHORT_SIZE = struct.Struct("=h").size + +MAGIC_NUMBER = 0x1950A86A20F9469CFC6C +PROTOCOL_VERSION = 1001 +STORAGE_KEY_SEPARATOR = "," + +MAP_LOCATION: TypeAlias = ( + Callable[[Storage, str], Storage] | torch.device | str | dict[str, str] | None +) +STORAGE: TypeAlias = Storage | torch.storage.TypedStorage | torch.UntypedStorage + +IS_WINDOWS = sys.platform == "win32" + +UNSAFE_MESSAGE = ( + "In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` " + "from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, " + "but it can result in arbitrary code execution. Do it only if you got the file from a " + "trusted source." 
+) + +if not IS_WINDOWS: + from mmap import MAP_PRIVATE, MAP_SHARED +else: + MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment] + + +def _default_to_weights_only(pickle_module): + is_fbcode = not hasattr(torch.version, "git_version") + return pickle_module is None and not is_fbcode + + +# _serialization_tls is used to store thread local state specific to serialization +# that needs to be propagated to other files, in particular we use this for +# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils) +# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) +# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) +class _SerializationLocal(threading.local): + def __init__(self): + super().__init__() + self.map_location: MAP_LOCATION | None = None + self.skip_data: bool = False + self.materialize_fake_tensors: bool = False + + +_serialization_tls = _SerializationLocal() + + +class SourceChangeWarning(Warning): + pass + + +@contextmanager +def mkdtemp(): + path = tempfile.mkdtemp() + try: + yield path + finally: + shutil.rmtree(path) + + +_package_registry: list[ + tuple[ + int, + Callable[[STORAGE], str | None], + Callable[[STORAGE, str], STORAGE | None], + ] +] = [] + + +class LoadEndianness(Enum): + NATIVE = 1 + LITTLE = 2 + BIG = 3 + + +def get_default_load_endianness() -> LoadEndianness | None: + """ + Get fallback byte order for loading files + + If byteorder mark is not present in saved checkpoint, + this byte order is used as fallback. + By default, it's "native" byte order. + + Returns: + default_load_endian: Optional[LoadEndianness] + """ + from torch.utils.serialization import config + + return config.load.endianness + + +def set_default_load_endianness(endianness): + """ + Set fallback byte order for loading files + + If byteorder mark is not present in saved checkpoint, + this byte order is used as fallback. + By default, it's "native" byte order. + + Args: + endianness: the new fallback byte order + """ + if not isinstance(endianness, LoadEndianness) and endianness is not None: + raise TypeError("Invalid argument type in function set_default_load_endianness") + from torch.utils.serialization import config + + config.load.endianness = endianness + + +def get_crc32_options() -> bool: + """ + Get whether :func:`torch.save` computes and writes crc32 for each record. + + Defaults to ``True``. + """ + from torch.utils.serialization import config + + return config.save.compute_crc32 + + +def set_crc32_options(compute_crc32: bool): + """ + Set whether :func:`torch.save` computes and writes crc32 for each record. + + .. note:: + Setting this to ``False`` may make unzipping of the ``torch.save`` output + fail or warn due to corrupted CRC32. However ``torch.load`` will be + able to load the file. + + Args: + compute_crc32 (bool): set crc32 computation flag + """ + from torch.utils.serialization import config + + config.save.compute_crc32 = compute_crc32 + + +def get_default_mmap_options() -> int | None: + """ + Get default mmap options for :func:`torch.load` with ``mmap=True``. + + Defaults to ``mmap.MAP_PRIVATE``. + + + Returns: + default_mmap_options: int + """ + from torch.utils.serialization import config + + return config.load.mmap_flags + + +def _get_storage_alignment() -> int: + """ + Gets alignment for storages in torch.save files/ + + Defaults to 64. 
+ + Returns: + storage_alginment: int + """ + from torch.utils.serialization import config + + return config.save.storage_alignment + + +class set_default_mmap_options: + """ + Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags. + + For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported. + Please open an issue if you need any other option to be added here. + + .. note:: + This feature is currently not supported for Windows. + + Args: + flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` + """ + + def __init__(self, flags: int) -> None: + if IS_WINDOWS: + raise RuntimeError( + "Changing the default mmap options is currently not supported for Windows" + ) + if flags != MAP_PRIVATE and flags != MAP_SHARED: + raise ValueError( + "Invalid argument in function set_default_mmap_options, " + f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}" + ) + # global config + from torch.utils.serialization import config + + self.prev = config.load.mmap_flags + config.load.mmap_flags = flags + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + from torch.utils.serialization import config + + config.load.mmap_flags = self.prev + + +def clear_safe_globals() -> None: + """ + Clears the list of globals that are safe for ``weights_only`` load. + """ + _weights_only_unpickler._clear_safe_globals() + + +def get_safe_globals() -> list[Callable | tuple[Callable, str]]: + """ + Returns the list of user-added globals that are safe for ``weights_only`` load. + """ + return _weights_only_unpickler._get_safe_globals() + + +def add_safe_globals(safe_globals: list[Callable | tuple[Callable, str]]) -> None: + """ + Marks the given globals as safe for ``weights_only`` load. For example, functions + added to this list can be called during unpickling, classes could be instantiated + and have state set. + + Each item in the list can either be a function/class or a tuple of the form + (function/class, string) where string is the full path of the function/class. + + Within the serialized format, each function is identified with its full + path as ``{__module__}.{__qualname__}``. When calling this API, you can provide this + full path that should match the one in the checkpoint otherwise the default + ``{fn.__module__}.{fn.__qualname__}`` will be used. + + Args: + safe_globals (List[Union[Callable, Tuple[Callable, str]]]): list of globals to mark as safe + + Example: + >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization") + >>> import tempfile + >>> class MyTensor(torch.Tensor): + ... pass + >>> t = MyTensor(torch.randn(2, 3)) + >>> with tempfile.NamedTemporaryFile() as f: + ... torch.save(t, f.name) + # Running `torch.load(f.name, weights_only=True)` will fail with + # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. + # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. + ... torch.serialization.add_safe_globals([MyTensor]) + ... torch.load(f.name, weights_only=True) + # MyTensor([[-0.5024, -1.8152, -0.5455], + # [-0.8234, 2.0500, -0.3657]]) + """ + _weights_only_unpickler._add_safe_globals(safe_globals) + + +class safe_globals(_weights_only_unpickler._safe_globals): + r"""Context-manager that adds certain globals as safe for ``weights_only`` load. + + Args: + safe_globals: List of globals for weights_only load. 
+ + Example: + >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization") + >>> import tempfile + >>> class MyTensor(torch.Tensor): + ... pass + >>> t = MyTensor(torch.randn(2, 3)) + >>> with tempfile.NamedTemporaryFile() as f: + ... torch.save(t, f.name) + # Running `torch.load(f.name, weights_only=True)` will fail with + # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. + # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. + ... with torch.serialization.safe_globals([MyTensor]): + ... torch.load(f.name, weights_only=True) + # MyTensor([[-0.5024, -1.8152, -0.5455], + # [-0.8234, 2.0500, -0.3657]]) + >>> assert torch.serialization.get_safe_globals() == [] + """ + + +def get_unsafe_globals_in_checkpoint(f: FileLike) -> list[str]: + """Returns a list of strings of functions/classes in a ``torch.save`` object that are not safe for ``weights_only``. + + For a given function or class ``f``, the corresponding string will be of the form + ``{f.__module__}.{f.__name__}``. + + This function will return any GLOBALs in the checkpoint that are not in the set marked safe + for ``weights_only`` (either via :func:`add_safe_globals` or :class:`safe_globals` context or + allowlisted by ``torch`` by default). + + .. note:: + This function will statically disassemble the pickle file in the checkpoint. + The implication is any classes dynamically pushed onto the stack during unpickling + will not be included in the output. + + Args: + f: File-like object or string containing the checkpoint object saved via ``torch.save`` + + Returns: + A list of strings of pickle GLOBALs in the checkpoint that are not allowlisted for ``weights_only``. + """ + default_safe_globals_strings = set( + _weights_only_unpickler._get_allowed_globals().keys() + ) + user_safe_global_strings = set( + _weights_only_unpickler._get_user_allowed_globals().keys() + ) + safe_global_strings = default_safe_globals_strings.union(user_safe_global_strings) + + with _open_file_like(f, "rb") as opened_file: + if not _is_zipfile(opened_file): + raise ValueError("Expected input to be a checkpoint returned by torch.save") + with _open_zipfile_reader(opened_file) as zip_file: + if _is_torchscript_zip(zip_file): + raise ValueError( + "Expected input to be a checkpoint returned by torch.save but got a torchscript checkpoint" + ) + data_file = io.BytesIO(zip_file.get_record("data.pkl")) + all_globals = _weights_only_unpickler.get_globals_in_pkl(data_file) + return list(all_globals.difference(safe_global_strings)) + + +class skip_data: + """ + Context-manager that skips writing/reading storage bytes for ``torch.save`` / ``torch.load`` calls. + + For the save path, storages will still be saved, but the space that their bytes would usually be written to + will be empty space. The storage bytes can then be populated in a separate pass. + + For the load path, tensors will be loaded per the checkpoint but their storages will not be populated with data. + + .. warning:: + The ``skip_data`` context manager is an early prototype and is subject to change. + + Args: + materialize_fake_tensors: Whether to materialize FakeTensors during save. This is a no-op for the load path. + + Example: + >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows") + >>> import tempfile + >>> t = torch.randn(2, 3) + >>> with tempfile.NamedTemporaryFile() as f: + ... with torch.serialization.skip_data(): + ... torch.save(t, f.name) + ... 
torch.load(f.name, weights_only=True) + tensor([[0., 0., 0.], + [0., 0., 0.]]) + """ + + def __init__(self, materialize_fake_tensors: bool = False): + self.materialize_fake_tensors = materialize_fake_tensors + + def __enter__(self): + global _serialization_tls + self._old_skip_data = _serialization_tls.skip_data + self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors + _serialization_tls.skip_data = True + _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors + + def __exit__(self, type, value, tb): + global _serialization_tls + _serialization_tls.skip_data = self._old_skip_data + _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors + + +def _is_zipfile(f) -> bool: + # This is a stricter implementation than zipfile.is_zipfile(). + # zipfile.is_zipfile() is True if the magic number appears anywhere in the + # binary. Since we expect the files here to be generated by torch.save or + # torch.jit.save, it's safe to only check the start bytes and avoid + # collisions and assume the zip has only 1 file. + # See bugs.python.org/issue28494. + + start = f.tell() + # Read the first few bytes and match against the ZIP file signature + local_header_magic_number = b"PK\x03\x04" + read_bytes = f.read(len(local_header_magic_number)) + f.seek(start) + return read_bytes == local_header_magic_number + + +def register_package( + priority: int, + tagger: Callable[[STORAGE], str | None], + deserializer: Callable[[STORAGE, str], STORAGE | None], +): + """ + Registers callables for tagging and deserializing storage objects with an associated priority. + Tagging associates a device with a storage object at save time while deserializing moves a + storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer` + are run in the order given by their :attr:`priority` until a tagger/deserializer returns a + value that is not `None`. + + To override the deserialization behavior for a device in the global registry, one can register a + tagger with a higher priority than the existing tagger. + + This function can also be used to register a tagger and deserializer for new devices. + + Args: + priority: Indicates the priority associated with the tagger and deserializer, where a lower + value indicates higher priority. + tagger: Callable that takes in a storage object and returns its tagged device as a string + or None. + deserializer: Callable that takes in storage object and a device string and returns a storage + object on the appropriate device or None. + + Returns: + `None` + + Example: + >>> def ipu_tag(obj): + >>> if obj.device.type == 'ipu': + >>> return 'ipu' + >>> def ipu_deserialize(obj, location): + >>> if location.startswith('ipu'): + >>> ipu = getattr(torch, "ipu", None) + >>> assert ipu is not None, "IPU device module is not loaded" + >>> assert torch.ipu.is_available(), "ipu is not available" + >>> return obj.ipu(location) + >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize) + """ + queue_elem = (priority, tagger, deserializer) + _package_registry.append(queue_elem) + _package_registry.sort() + + +def check_module_version_greater_or_equal( + module, + req_version_tuple, + error_if_malformed=True, +): + """ + Check if a module's version satisfies requirements + + Usually, a module's version string will be like 'x.y.z', which would be represented + as a tuple (x, y, z), but sometimes it could be an unexpected format. 
If the version + string does not match the given tuple's format up to the length of the tuple, then + error and exit or emit a warning. + + Args: + module: the module to check the version of + req_version_tuple: tuple (usually of ints) representing the required version + error_if_malformed: whether we should exit if module version string is malformed + + Returns: + requirement_is_met: bool + """ + try: + version_strs = module.__version__.split(".") + # Cast module version fields to match the types of the required version + module_version = tuple( + type(req_field)(version_strs[idx]) + for idx, req_field in enumerate(req_version_tuple) + ) + requirement_is_met = module_version >= req_version_tuple + + except Exception as e: + message = ( + f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared" + f" with tuple {str(req_version_tuple)}" + ) + if error_if_malformed: + raise RuntimeError(message) from e + else: + warnings.warn( + message + ", but continuing assuming that requirement is met", + stacklevel=2, + ) + requirement_is_met = True + + return requirement_is_met + + +def _cpu_tag(obj): + if obj.device.type == "cpu": + return "cpu" + + +def _mps_tag(obj): + if obj.device.type == "mps": + return "mps" + + +def _meta_tag(obj): + if obj.device.type == "meta": + return "meta" + + +def _backend_tag(backend_name, obj): + if backend_name == "privateuse1": + backend_name = torch._C._get_privateuse1_backend_name() + if obj.device.type == backend_name: + if obj.device.index is None: + return backend_name + else: + return backend_name + ":" + str(obj.device.index) + + +def _cpu_deserialize(obj, location): + if location == "cpu": + return obj + + +def _mps_deserialize(obj, location): + if location.startswith("mps"): + return obj.mps() + + +def _meta_deserialize(obj, location): + if location == "meta": + return torch.UntypedStorage(obj.nbytes(), device="meta") + + +def _validate_device(location, backend_name): + """ + Check whether the device index of specified backend is valid + + In case of privateuse1 backend, your must first register a device_module for + privateuse1 using torch._register_device_module. Implement the following + methods in device_module like cuda: device_module._utils._get_device_index(location, True), + device_module.device_count(). + + Args: + location: string of device + backend_name: the backend name or the name of privateuse1, which can be renamed + + Returns: + device_index: int + """ + if not hasattr(torch, backend_name): + raise RuntimeError( + f"The {backend_name.upper()} device module is not registered. " + "If you are running on a CPU-only machine, " + "please use torch.load with map_location=torch.device('cpu') " + "to map your storages to the CPU." + ) + device_module = getattr(torch, backend_name) + if hasattr(device_module, "_utils") and hasattr( + device_module._utils, "_get_device_index" + ): + device_index = device_module._utils._get_device_index(location, True) + device = torch.device(backend_name, device_index) + else: + device = torch.device(location) + device_index = device.index if device.index else 0 + if hasattr(device_module, "is_available") and not device_module.is_available(): + raise RuntimeError( + f"Attempting to deserialize object on a {backend_name.upper()} " + f"device but torch.{backend_name}.is_available() is False. " + "If you are running on a CPU-only machine, " + "please use torch.load with map_location=torch.device('cpu') " + "to map your storages to the CPU." 
+ ) + if hasattr(device_module, "device_count"): + device_count = device_module.device_count() + if device_index >= device_count: + raise RuntimeError( + f"Attempting to deserialize object on {backend_name.upper()} device " + f"{device_index} but torch.{backend_name}.device_count() is {device_count}. " + "Please use torch.load with map_location to map your storages " + "to an existing device." + ) + return device + + +def validate_cuda_device(location): + return _validate_device(location, "cuda").index + + +def validate_hpu_device(location): + return _validate_device(location, "hpu").index + + +def _deserialize(backend_name, obj, location): + if backend_name == "privateuse1": + backend_name = torch._C._get_privateuse1_backend_name() + if location.startswith(backend_name): + device = _validate_device(location, backend_name) + return obj.to(device=device) + + +register_package(10, _cpu_tag, _cpu_deserialize) +register_package( + 20, + functools.partial(_backend_tag, "cuda"), + functools.partial(_deserialize, "cuda"), +) +register_package(21, _mps_tag, _mps_deserialize) +register_package(22, _meta_tag, _meta_deserialize) +register_package( + 23, + functools.partial(_backend_tag, "privateuse1"), + functools.partial(_deserialize, "privateuse1"), +) +register_package( + 24, + functools.partial(_backend_tag, "hpu"), + functools.partial(_deserialize, "hpu"), +) +register_package( + 25, + functools.partial(_backend_tag, "xpu"), + functools.partial(_deserialize, "xpu"), +) +register_package( + 26, + functools.partial(_backend_tag, "mtia"), + functools.partial(_deserialize, "mtia"), +) + + +def location_tag( + storage: Storage | torch.storage.TypedStorage | torch.UntypedStorage, +): + for _, tagger, _ in _package_registry: + location = tagger(storage) + if location: + return location + raise RuntimeError( + "don't know how to determine data location of " + torch.typename(storage) + ) + + +def default_restore_location(storage, location): + """ + Restores `storage` using a deserializer function registered for the `location`. + + This function looks in the registry for deserializer functions that match the `location`. + If found, it attempts to use them, in priority order, to restore `storage` until one + returns a not `None` result. If no deserializer can be found in the registry, or all found fail + to bear a result, it raises a `RuntimeError`. + + Args: + storage (STORAGE): the storage object to restore + location (str): the location tag associated with the storage object + + Returns: + storage: Optional[STORAGE] + + Raises: + RuntimeError: If no deserializer matching `location` is found in the registry or if + all matching ones return `None`. 
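+
+    Example (illustrative sketch; relies on the built-in ``cpu`` deserializer
+    registered above, which returns CPU storages unchanged):
+
+        >>> s = torch.empty(4).untyped_storage()
+        >>> torch.serialization.default_restore_location(s, "cpu") is s
+        True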
+ """ + for _, _, fn in _package_registry: + result = fn(storage, location) + if result is not None: + return result + raise RuntimeError( + "don't know how to restore data location of " + + torch.typename(storage) + + " (tagged with " + + location + + ")" + ) + + +def normalize_storage_type(storage_type): + return getattr(torch, storage_type.__name__) + + +def storage_to_tensor_type(storage): + storage_type = type(storage) + module = _import_dotted_name(storage_type.__module__) + return getattr(module, storage_type.__name__.replace("Storage", "Tensor")) + + +def _is_path(name_or_buffer: object) -> TypeIs[str | os.PathLike]: + return isinstance(name_or_buffer, (str, os.PathLike)) + + +T = TypeVar("T") + + +class _opener(Generic[T]): + def __init__(self, file_like: T) -> None: + self.file_like: T = file_like + + def __enter__(self): + return self.file_like + + def __exit__(self, *args): + pass + + +class _open_file(_opener[IO[bytes]]): + def __init__(self, name: str | os.PathLike[str], mode: str) -> None: + super().__init__(open(name, mode)) # noqa: SIM115 + + def __exit__(self, *args): + self.file_like.close() + + +class _open_buffer_reader(_opener[IO[bytes]]): + def __init__(self, buffer: IO[bytes]) -> None: + super().__init__(buffer) + _check_seekable(buffer) + + +class _open_buffer_writer(_opener[IO[bytes]]): + def __exit__(self, *args): + self.file_like.flush() + + +def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]: + if _is_path(name_or_buffer): + return _open_file(name_or_buffer, mode) + else: + if "w" in mode: + return _open_buffer_writer(name_or_buffer) + elif "r" in mode: + return _open_buffer_reader(name_or_buffer) + else: + raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}") + + +class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]): + def __init__(self, name_or_buffer: str | IO[bytes]) -> None: + super().__init__(torch._C.PyTorchFileReader(name_or_buffer)) + + +class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]): + def __init__(self, name: str) -> None: + self.file_stream = None + self.name = name + try: + self.name.encode("ascii") + except UnicodeEncodeError: + # PyTorchFileWriter only supports ascii filename. + # For filenames with non-ascii characters, we rely on Python + # for writing out the file. 
+ # pyrefly: ignore [bad-assignment] + self.file_stream = io.FileIO(self.name, mode="w") + super().__init__( + torch._C.PyTorchFileWriter( # pyrefly: ignore # no-matching-overload + self.file_stream, get_crc32_options(), _get_storage_alignment() + ) + ) + else: + super().__init__( + torch._C.PyTorchFileWriter( + self.name, get_crc32_options(), _get_storage_alignment() + ) + ) + + def __exit__(self, *args) -> None: + self.file_like.write_end_of_file() + if self.file_stream is not None: + self.file_stream.close() + + +class _open_zipfile_writer_buffer(_opener[torch._C.PyTorchFileWriter]): + def __init__(self, buffer: IO[bytes]) -> None: + if not callable(getattr(buffer, "write", None)): + msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'" + if not hasattr(buffer, "write"): + raise AttributeError(msg) + raise TypeError(msg) + self.buffer = buffer + super().__init__( + torch._C.PyTorchFileWriter( + buffer, get_crc32_options(), _get_storage_alignment() + ) + ) + + def __exit__(self, *args) -> None: + self.file_like.write_end_of_file() + self.buffer.flush() + + +def _open_zipfile_writer(name_or_buffer: str | IO[bytes]) -> _opener: + container: type[_opener] + if _is_path(name_or_buffer): + container = _open_zipfile_writer_file + else: + container = _open_zipfile_writer_buffer + return container(name_or_buffer) # type: ignore[arg-type] + + +def _is_compressed_file(f) -> bool: + compress_modules = ["gzip"] + try: + return f.__module__ in compress_modules + except AttributeError: + return False + + +def _should_read_directly(f): + """ + Checks if f is a file that should be read directly. It should be read + directly if it is backed by a real file (has a fileno) and is not a + a compressed file (e.g. gzip) + """ + if _is_compressed_file(f): + return False + try: + return f.fileno() >= 0 + except io.UnsupportedOperation: + return False + except AttributeError: + return False + + +def _check_seekable(f) -> bool: + def raise_err_msg(patterns, e): + for p in patterns: + if p in str(e): + msg = ( + str(e) + + ". You can only torch.load from a file that is seekable." + + " Please pre-load the data into a buffer like io.BytesIO and" + + " try to load from it instead." + ) + raise type(e)(msg) + raise e + + try: + f.seek(f.tell()) + return True + except (io.UnsupportedOperation, AttributeError) as e: + raise_err_msg(["seek", "tell"], e) + return False + + +def _check_dill_version(pickle_module) -> None: + """Checks if using dill as the pickle module, and if so, checks if it is the correct version. + If dill version is lower than 0.3.1, a ValueError is raised. + + Args: + pickle_module: module used for pickling metadata and objects + + """ + if pickle_module is not None and pickle_module.__name__ == "dill": + required_dill_version = (0, 3, 1) + if not check_module_version_greater_or_equal( + pickle_module, required_dill_version, False + ): + raise ValueError( + ( + "'torch' supports dill >= {}, but you have dill {}." 
+ " Please upgrade dill or switch to 'pickle'" + ).format( + ".".join([str(num) for num in required_dill_version]), + pickle_module.__version__, + ) + ) + + +def _check_save_filelike(f): + if not _is_path(f) and not hasattr(f, "write"): + raise AttributeError( + "expected 'f' to be string, path, or a file-like object with " + "a 'write' attribute" + ) + + +def save( + obj: object, + f: FileLike, + pickle_module: Any = pickle, + pickle_protocol: int = DEFAULT_PROTOCOL, + _use_new_zipfile_serialization: bool = True, + _disable_byteorder_record: bool = False, +) -> None: + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `pickle`s path from + # the build environment (e.g. `>> # xdoctest: +SKIP("makes cwd dirty") + >>> # Save to file + >>> x = torch.tensor([0, 1, 2, 3, 4]) + >>> torch.save(x, "tensor.pt") + >>> # Save to io.BytesIO buffer + >>> buffer = io.BytesIO() + >>> torch.save(x, buffer) + """ + torch._C._log_api_usage_once("torch.save") + _check_dill_version(pickle_module) + _check_save_filelike(f) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + if _use_new_zipfile_serialization: + with _open_zipfile_writer(f) as opened_zipfile: + _save( + obj, + opened_zipfile, + pickle_module, + pickle_protocol, + _disable_byteorder_record, + ) + return + else: + global _serialization_tls + if _serialization_tls.skip_data: + raise RuntimeError( + "Cannot use skip_data=True with _use_new_zipfile_serialization=False" + ) + with _open_file_like(f, "wb") as opened_file: + _legacy_save(obj, opened_file, pickle_module, pickle_protocol) + + +def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: + import torch.nn as nn + + serialized_container_types = {} + serialized_storages: dict[str, tuple[torch.UntypedStorage, torch.dtype]] = {} + + # Since loading storages that view the same data with different dtypes is + # not supported, we need to keep track of the dtype associated with each + # storage data_ptr and throw an error if the dtype is ever different. + # TODO: This feature could be added in the future + storage_dtypes: dict[int, torch.dtype] = {} + + def persistent_id(obj: Any) -> tuple | None: + # FIXME: the docs say that persistent_id should only return a string + # but torch store returns tuples. This works only in the binary protocol + # see + # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects + # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 + if isinstance(obj, type) and issubclass(obj, nn.Module): + if obj in serialized_container_types: + return None + serialized_container_types[obj] = True + source_file = source = None + try: + source_lines, _, source_file = get_source_lines_and_file(obj) + source = "".join(source_lines) + except ( + Exception + ): # saving the source is optional, so we can ignore any errors + warnings.warn( + "Couldn't retrieve source code for container of " + "type " + obj.__name__ + ". 
It won't be checked " + "for correctness upon loading.", + stacklevel=2, + ) + return ("module", obj, source_file, source) + + if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): + storage: torch.UntypedStorage + + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, this case + # can be deleted + storage = obj._untyped_storage + storage_dtype = obj.dtype + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + dtype = obj.dtype + storage_numel = obj._size() + + elif isinstance(obj, torch.UntypedStorage): + storage = obj + storage_dtype = torch.uint8 + storage_type = normalize_storage_type(type(obj)) + dtype = torch.uint8 + storage_numel = storage.nbytes() + else: + raise TypeError(f"type not recognized: {type(obj)}") + + # If storage is allocated, ensure that any other saved storages + # pointing to the same data all have the same dtype. If storage is + # not allocated, don't perform this check + if storage.data_ptr() != 0: + if storage.data_ptr() in storage_dtypes: + if storage_dtype != storage_dtypes[storage.data_ptr()]: + raise RuntimeError( + "Cannot save multiple tensors or storages that " + "view the same data as different types" + ) + else: + storage_dtypes[storage.data_ptr()] = storage_dtype + + view_metadata: tuple[str, int, int] | None + + # Offset is always 0, but we keep it for backwards compatibility + # with the old serialization format (which supported storage views) + offset = 0 + storage_key = str(storage._cdata) + location = location_tag(storage) + + # TODO: There's an issue here with FC. It might be impossible to + # solve, but it's worth noting. Imagine we save a list `[storage, + # tensor]`, where `tensor.storage()` is the same as `storage`, and + # `tensor.element_size() > 1`. Let's say that `tensor.dtype == + # torch.float`. The storage will be serialized with element size + # of 1, since we're choosing to serialize the first occurrence of + # a duplicate storage. Since this legacy serialization format saves + # the numel of the storage, rather than nbytes directly, we'll be + # effectively saving nbytes in this case. We'll be able to load it + # and the tensor back up with no problems in _this_ and future + # versions of pytorch, but in older versions, here's the problem: + # the storage will be loaded up as a UntypedStorage, and then the + # FloatTensor will loaded and the UntypedStorage will be assigned to + # it. Since the storage dtype does not match the tensor dtype, this + # will cause an error. If we reverse the list, like `[tensor, + # storage]`, then we will save the `tensor.storage()` as a faked + # `FloatStorage`, and the saved size will be the correct + # dtype-specific numel count that old versions expect. `tensor` + # will be able to load up properly in old versions, pointing to + # a FloatStorage. However, `storage` is still being translated to + # a UntypedStorage, and it will try to resolve to the same + # FloatStorage that `tensor` contains. This will also cause an + # error. It doesn't seem like there's any way around this. + # Probably, we just cannot maintain FC for the legacy format if the + # saved list contains both a tensor and a storage that point to the + # same data. We should still be able to maintain FC for lists of + # just tensors, as long as all views share the same dtype as the + # tensor they are viewing. 
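+            # Record the first occurrence of each storage; its bytes are written once,
+            # after the pickle data, keyed by the storage's cdata pointer.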
+ + if storage_key not in serialized_storages: + serialized_storages[storage_key] = (storage, dtype) + is_view = storage._cdata != storage._cdata + if is_view: + view_metadata = (str(storage._cdata), offset, storage.nbytes()) + else: + view_metadata = None + + res = ( + "storage", + storage_type, + storage_key, + location, + storage_numel, + view_metadata, + ) + return res + return None + + sys_info = { + "protocol_version": PROTOCOL_VERSION, + "little_endian": sys.byteorder == "little", + "type_sizes": { + "short": SHORT_SIZE, + "int": INT_SIZE, + "long": LONG_SIZE, + }, + } + + pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol) + pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol) + pickle_module.dump(sys_info, f, protocol=pickle_protocol) + + class PyTorchLegacyPickler(pickle_module.Pickler): + def persistent_id(self, obj): + return persistent_id(obj) # noqa: F821 + + pickler = PyTorchLegacyPickler(f, protocol=pickle_protocol) + pickler.dump(obj) + + # The class def keeps the persistent_id closure alive, leaking memory. + del persistent_id + + serialized_storage_keys = sorted(serialized_storages.keys()) + pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol) + f.flush() + for key in serialized_storage_keys: + storage, dtype = serialized_storages[key] + storage._write_file( + f, _should_read_directly(f), True, torch._utils._element_size(dtype) + ) + + +def _save( + obj, + zip_file, + pickle_module, + pickle_protocol, + _disable_byteorder_record, +): + serialized_storages: dict[str, torch.storage.UntypedStorage] = {} + id_map: dict[int, str] = {} + + # Since loading storages that view the same data with different dtypes is + # not supported, we need to keep track of the dtype associated with each + # storage data_ptr and throw an error if the dtype is ever different. + # TODO: This feature could be added in the future + storage_dtypes: dict[int, torch.dtype] = {} + + def persistent_id(obj): + # FIXME: the docs say that persistent_id should only return a string + # but torch store returns tuples. This works only in the binary protocol + # see + # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects + # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 + if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, this case + # can be deleted + storage = obj._untyped_storage + storage_dtype = obj.dtype + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + storage_numel = obj._size() + + else: + storage = obj + storage_dtype = torch.uint8 + storage_type = normalize_storage_type(type(obj)) + storage_numel = storage.nbytes() + + # If storage is allocated, ensure that any other saved storages + # pointing to the same data all have the same dtype. 
If storage is + # not allocated, don't perform this check + if str(storage.device) != "meta" and storage.data_ptr() != 0: + if storage.data_ptr() in storage_dtypes: + if storage_dtype != storage_dtypes[storage.data_ptr()]: + raise RuntimeError( + "Cannot save multiple tensors or storages that " + "view the same data as different types" + ) + else: + storage_dtypes[storage.data_ptr()] = storage_dtype + + storage_key = id_map.setdefault(storage._cdata, str(len(id_map))) + if hasattr(obj, "_fake_device") and obj._fake_device is not None: + location = str(obj._fake_device) + else: + location = location_tag(storage) + serialized_storages[storage_key] = storage + + return ("storage", storage_type, storage_key, location, storage_numel) + + return None + + # Write the pickle data for `obj` + data_buf = io.BytesIO() + + class PyTorchPickler(pickle_module.Pickler): # type: ignore[name-defined] + def persistent_id(self, obj): + return persistent_id(obj) # noqa: F821 + + pickler = PyTorchPickler(data_buf, protocol=pickle_protocol) + pickler.dump(obj) + + # The class def keeps the persistent_id closure alive, leaking memory. + del persistent_id + + data_value = data_buf.getvalue() + zip_file.write_record("data.pkl", data_value, len(data_value)) + # .format_version is used to track + # 1. version 1 represents the order of storages being changed from + # lexicographical based on keys to numerically ordered based on keys + # 2. version 2 represents including storage_alignment as a record + # within the zipfile + zip_file.write_record(".format_version", "1", len("1")) + storage_alignment = str(_get_storage_alignment()) + zip_file.write_record( + ".storage_alignment", storage_alignment, len(storage_alignment) + ) + + # Write byte order marker + if not _disable_byteorder_record: + if sys.byteorder not in ["little", "big"]: + raise ValueError("Unknown endianness type: " + sys.byteorder) + + zip_file.write_record("byteorder", sys.byteorder, len(sys.byteorder)) + + # Write each tensor to a file named tensor/the_tensor_key in the zip archive + for key in serialized_storages: + name = f"data/{key}" + storage = serialized_storages[key] + num_bytes = storage.nbytes() + global _serialization_tls + if _serialization_tls.skip_data: + zip_file.write_record_metadata(name, num_bytes) + else: + # given that we copy things around anyway, we might use storage.cpu() + # this means to that to get tensors serialized, you need to implement + # .cpu() on the underlying Storage + if storage.device.type != "cpu": + from torch.utils.serialization import config + + if ( + config.save.use_pinned_memory_for_d2h + and ( + acc := torch.accelerator.current_accelerator( + check_available=True + ) + ) + is not None + and acc.type == storage.device.type + ): + new_storage = torch.empty( + num_bytes, dtype=torch.uint8, device="cpu", pin_memory=True + ).untyped_storage() + new_storage.copy_(storage) + torch.accelerator.current_stream(storage.device.index).synchronize() + storage = new_storage + else: + storage = storage.cpu() + # Now that it is on the CPU we can directly copy it into the zip file + zip_file.write_record(name, storage, num_bytes) + + +def load( + f: FileLike, + map_location: MAP_LOCATION = None, + pickle_module: Any = None, + *, + weights_only: bool | None = None, + mmap: bool | None = None, + **pickle_load_args: Any, +) -> Any: + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. 
We need it so that Sphinx doesn't leak `pickle`s path from + # the build environment (e.g. `>> # xdoctest: +SKIP("undefined filepaths") + >>> torch.load("tensors.pt", weights_only=True) + # Load all tensors onto the CPU + >>> torch.load( + ... "tensors.pt", + ... map_location=torch.device("cpu"), + ... weights_only=True, + ... ) + # Load all tensors onto the CPU, using a function + >>> torch.load( + ... "tensors.pt", + ... map_location=lambda storage, loc: storage, + ... weights_only=True, + ... ) + # Load all tensors onto GPU 1 + >>> torch.load( + ... "tensors.pt", + ... map_location=lambda storage, loc: storage.cuda(1), # type: ignore[attr-defined] + ... weights_only=True, + ... ) # type: ignore[attr-defined] + # Map tensors from GPU 1 to GPU 0 + >>> torch.load( + ... "tensors.pt", + ... map_location={"cuda:1": "cuda:0"}, + ... weights_only=True, + ... ) + # Load tensor from io.BytesIO object + # Loading from a buffer setting weights_only=False, warning this can be unsafe + >>> with open("tensor.pt", "rb") as f: + ... buffer = io.BytesIO(f.read()) + >>> torch.load(buffer, weights_only=False) + # Load a module with 'ascii' encoding for unpickling + # Loading from a module setting weights_only=False, warning this can be unsafe + >>> torch.load("module.pt", encoding="ascii", weights_only=False) + """ + torch._C._log_api_usage_once("torch.load") + DOCS_MESSAGE = ( + "\n\nCheck the documentation of torch.load to learn more about types accepted by default with " + "weights_only https://pytorch.org/docs/stable/generated/torch.load.html." + ) + + def _get_wo_message(message: str) -> str: + unsafe_global_pattern = r"GLOBAL (\S+) was not an allowed global by default." + has_unsafe_global = re.search(unsafe_global_pattern, message) is not None + blocklist_pattern = r"whose module (\S+) is blocked" + has_blocklist = re.search(blocklist_pattern, message) is not None + import_pattern = r"(\S+) must be (\S+) to load" + has_import = re.search(import_pattern, message) is not None + if has_unsafe_global: + updated_message = ( + "Weights only load failed. This file can still be loaded, to do so you have two options, " + "\033[1mdo those steps only if you trust the source of the checkpoint\033[0m. " + f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check " + "the recommended steps in the following error message.\n\tWeightsUnpickler error: " + + message + ) + else: + if has_import: + return f"Weights only load failed. {message}\n {UNSAFE_MESSAGE}\n" + else: + updated_message = f"Weights only load failed. {UNSAFE_MESSAGE}\n" + if not has_blocklist: + updated_message += ( + "Please file an issue with the following so that we can make " + "`weights_only=True` compatible with your use case: WeightsUnpickler error: " + ) + updated_message += "\n\n" + message + return updated_message + DOCS_MESSAGE + + weights_only_not_set = weights_only is None + + if weights_only_not_set: + weights_only = _default_to_weights_only(pickle_module) + + true_values = ["1", "y", "yes", "true"] + # Add ability to force safe only or non-safe weight loads via environment variables + force_weights_only_load = ( + os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0") in true_values + ) + force_no_weights_only_load = ( + os.getenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "0") in true_values + ) + + if force_weights_only_load and force_no_weights_only_load: + raise RuntimeError( + "Only one of `TORCH_FORCE_WEIGHTS_ONLY_LOAD` or `TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD` " + "should be set, but both were set." 
+ ) + elif force_weights_only_load: + weights_only = True + elif force_no_weights_only_load: + # TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD can only override if callsite did not explicitly set weights_only + if weights_only_not_set: + warnings.warn( + "Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the" + "`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.", + UserWarning, + stacklevel=2, + ) + weights_only = False + + if weights_only: + if pickle_module is not None: + raise RuntimeError( + "Can not safely load weights when explicit pickle_module is specified" + ) + else: + if pickle_module is None: + pickle_module = pickle + + # make flipping default BC-compatible + if mmap is None: + from torch.utils.serialization import config + + mmap = config.load.mmap + + _check_dill_version(pickle_module) + + if "encoding" not in pickle_load_args: + pickle_load_args["encoding"] = "utf-8" + + with _open_file_like(f, "rb") as opened_file: + if _is_zipfile(opened_file): + # The zipfile reader is going to advance the current file position. + # If we want to actually tail call to torch.jit.load, we need to + # reset back to the original position. + orig_position = opened_file.tell() + overall_storage = None + with _open_zipfile_reader(opened_file) as opened_zipfile: + if _is_torchscript_zip(opened_zipfile): + warnings.warn( + "'torch.load' received a zip file that looks like a TorchScript archive" + " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to" + " silence this warning)", + UserWarning, + stacklevel=2, + ) + if weights_only: + raise RuntimeError( + "Cannot use ``weights_only=True`` with TorchScript archives passed to " + "``torch.load``. " + UNSAFE_MESSAGE + ) + opened_file.seek(orig_position) + return torch.jit.load(opened_file, map_location=map_location) + if mmap: + if not _is_path(f): + raise ValueError( + "f must be a file path in order to use the mmap argument" + ) + size = os.path.getsize(f) + if not IS_WINDOWS: + shared = get_default_mmap_options() == MAP_SHARED + else: + shared = False + overall_storage = torch.UntypedStorage.from_file( + os.fspath(f), + shared, + size, + ) + if weights_only: + try: + return _load( + opened_zipfile, + map_location, + _weights_only_unpickler, + overall_storage=overall_storage, + **pickle_load_args, + ) + except pickle.UnpicklingError as e: + raise pickle.UnpicklingError(_get_wo_message(str(e))) from None + return _load( + opened_zipfile, + map_location, + pickle_module, + overall_storage=overall_storage, + **pickle_load_args, + ) + if mmap: + f_name = "" if not isinstance(f, str) else f"{f}, " + raise RuntimeError( + "mmap can only be used with files saved with " + f"`torch.save({f_name}_use_new_zipfile_serialization=True), " + "please torch.save your checkpoint with this option in order to use mmap." 
+ ) + if weights_only: + try: + return _legacy_load( + opened_file, + map_location, + _weights_only_unpickler, + **pickle_load_args, + ) + except pickle.UnpicklingError as e: + raise pickle.UnpicklingError(_get_wo_message(str(e))) from None + return _legacy_load( + opened_file, map_location, pickle_module, **pickle_load_args + ) + + +# Register pickling support for layout instances such as +# torch.sparse_coo, etc +def _get_layout(name): + """Get layout extension object from its string representation.""" + cache = _get_layout.cache # type: ignore[attr-defined] + if not cache: + for v in torch.__dict__.values(): + if isinstance(v, torch.layout): + cache[str(v)] = v + return cache[name] + + +# There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087 +_get_layout.cache = {} # type: ignore[attr-defined] +copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),))) + + +def _legacy_load(f, map_location, pickle_module, **pickle_load_args): + deserialized_objects: dict[int, Any] = {} + + restore_location = _get_restore_location(map_location) + + class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] + def find_class(self, mod_name, name): + if type(name) is str and "Storage" in name: + try: + return StorageType(name) + except KeyError: + pass + return super().find_class(mod_name, name) + + def _check_container_source(container_type, source_file, original_source): + try: + current_source = "".join(get_source_lines_and_file(container_type)[0]) + except Exception: # saving the source is optional, so we can ignore any errors + warnings.warn( + "Couldn't retrieve source code for container of " + "type " + container_type.__name__ + ". It won't be checked " + "for correctness upon loading.", + stacklevel=2, + ) + return + if original_source != current_source: + if container_type.dump_patches: + file_name = container_type.__name__ + ".patch" + diff = difflib.unified_diff( + current_source.split("\n"), + original_source.split("\n"), + source_file, + source_file, + lineterm="", + ) + lines = "\n".join(diff) + try: + with open(file_name, "a+") as f: + file_size = f.seek(0, 2) + f.seek(0) + if file_size == 0: + f.write(lines) + elif file_size != len(lines) or f.read() != lines: + raise OSError + msg = ( + "Saved a reverse patch to " + file_name + ". " + "Run `patch -p0 < " + file_name + "` to revert your " + "changes." + ) + except OSError: + msg = ( + "Tried to save a patch, but couldn't create a " + "writable file " + file_name + ". Make sure it " + "doesn't exist and your working directory is " + "writable." + ) + else: + msg = ( + "you can retrieve the original source code by " + "accessing the object's source attribute or set " + "`torch.nn.Module.dump_patches = True` and use the " + "patch tool to revert the changes." + ) + msg = f"source code of class '{torch.typename(container_type)}' has changed. 
{msg}" + warnings.warn(msg, SourceChangeWarning, stacklevel=2) + + def legacy_load(f): + deserialized_objects: dict[int, Any] = {} + + def persistent_load(saved_id): + if isinstance(saved_id, tuple): + # Ignore containers that don't have any sources saved + if all(saved_id[1:]): + _check_container_source(*saved_id) + return saved_id[0] + return deserialized_objects[int(saved_id)] + + with ( + closing( + tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT) + ) as tar, + mkdtemp() as tmpdir, + ): + if pickle_module is _weights_only_unpickler: + raise RuntimeError( + "Cannot use ``weights_only=True`` with files saved in the " + "legacy .tar format. " + UNSAFE_MESSAGE + ) + tar.extract("storages", path=tmpdir) + with open(os.path.join(tmpdir, "storages"), "rb", 0) as f: + num_storages = pickle_module.load(f, **pickle_load_args) + for _ in range(num_storages): + args = pickle_module.load(f, **pickle_load_args) + key, location, storage_type = args + dtype = storage_type._dtype + obj = cast(Storage, torch.UntypedStorage)._new_with_file( + f, torch._utils._element_size(dtype) + ) + obj = restore_location(obj, location) + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + deserialized_objects[key] = torch.storage.TypedStorage( + wrap_storage=obj, dtype=dtype, _internal=True + ) + + storage_views = pickle_module.load(f, **pickle_load_args) + for target_cdata, root_cdata, offset, numel in storage_views: + root = deserialized_objects[root_cdata] + element_size = torch._utils._element_size(root.dtype) + offset_bytes = offset * element_size + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + deserialized_objects[target_cdata] = torch.storage.TypedStorage( + wrap_storage=root._untyped_storage[ + offset_bytes : offset_bytes + numel * element_size + ], + dtype=root.dtype, + _internal=True, + ) + + tar.extract("tensors", path=tmpdir) + with open(os.path.join(tmpdir, "tensors"), "rb", 0) as f: + num_tensors = pickle_module.load(f, **pickle_load_args) + for _ in range(num_tensors): + args = pickle_module.load(f, **pickle_load_args) + key, storage_id, _original_tensor_type = args + storage = deserialized_objects[storage_id] + (ndim,) = struct.unpack(" str: + # When using encoding='bytes' in Py3, some **internal** keys stored as + # strings in Py2 are loaded as bytes. This function decodes them with + # ascii encoding, one that Py3 uses by default. + # + # NOTE: This should only be used on internal keys (e.g., `typename` and + # `location` in `persistent_load` below! 
+ if isinstance(bytes_str, bytes): + return bytes_str.decode("ascii") + return bytes_str + + +def _get_restore_location(map_location): + if map_location is None: + restore_location = default_restore_location + elif isinstance(map_location, dict): + + def restore_location(storage, location): + location = map_location.get(location, location) + return default_restore_location(storage, location) + + elif isinstance(map_location, (str, bytes)): + + def restore_location(storage, location): + return default_restore_location(storage, map_location) + + elif isinstance(map_location, torch.device): + + def restore_location(storage, location): + return default_restore_location(storage, str(map_location)) + + else: + + def restore_location(storage, location): + result = map_location(storage, location) + if result is None: + result = default_restore_location(storage, location) + return result + + return restore_location + + +class StorageType: + def __init__(self, name): + self._dtype = _get_dtype_from_pickle_storage_type(name) + + @property + def dtype(self): + return self._dtype + + def __str__(self): + return f"StorageType(dtype={self.dtype})" + + +def _load( + zip_file, + map_location, + pickle_module, + pickle_file="data.pkl", + overall_storage=None, + **pickle_load_args, +): + restore_location = _get_restore_location(map_location) + + loaded_storages = {} + + can_calculate_storage_offsets = False + if zip_file.has_record(".format_version"): + version = zip_file.get_record(".format_version") + can_calculate_storage_offsets = version >= b"1" + + # check if byteswapping is needed + byteordername = "byteorder" + byteorderdata = None + if zip_file.has_record(byteordername): + byteorderdata = zip_file.get_record(byteordername) + if byteorderdata not in [b"little", b"big"]: + raise ValueError("Unknown endianness type: " + byteorderdata.decode()) + elif ( + get_default_load_endianness() == LoadEndianness.LITTLE + or get_default_load_endianness() is None + ): + byteorderdata = b"little" + elif get_default_load_endianness() == LoadEndianness.BIG: + byteorderdata = b"big" + elif get_default_load_endianness() == LoadEndianness.NATIVE: + pass + else: + raise ValueError("Invalid load endianness type") + + storage_alignment = 64 + if zip_file.has_record(".storage_alignment"): + storage_alignment = int(zip_file.get_record(".storage_alignment")) + + if ( + not zip_file.has_record(byteordername) + and get_default_load_endianness() is None + and sys.byteorder == "big" + ): + # Default behaviour was changed + # See https://github.com/pytorch/pytorch/issues/101688 + warnings.warn( + "The default load endianness for checkpoints without a byteorder mark " + "on big endian machines was changed from 'native' to 'little' endian, " + "to avoid this behavior please use " + "torch.serialization.set_default_load_endianness to set " + "the desired default load endianness", + UserWarning, + stacklevel=2, + ) + + from torch.utils.serialization import config + + calculate_storage_offsets = config.load.calculate_storage_offsets + run_debug_asserts = os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1" + current_offset = None + # constants from miniz.h/miniz.c + data_descripter_size64 = 24 + data_descripter_size32 = 16 + mz_uint32_max = 0xFFFFFFFF + offsets: dict[str, int] = dict() + + def _get_offset(key, name, numel): + """ + Return the offset of the storage associated with key with record name `name` and size numel. + It is expected that the zipfile header of this storage starts at current_offset. 
+ + WARNING: This function relies on the behavior of the zipwriter in miniz.c. In particular, + the behavior of `mz_zip_writer_add_mem_ex_v2`. The behavior of this function must be kept + in sync with that of miniz! + + After reading a storage of size numel that starts at storage_offset + if it is the first time that storage was read, update nonlocal variable + current_offset to the start of the next zipfile header by incrementing + it by numel and the data descriptor size. + """ + nonlocal current_offset, offsets + if name in offsets: + storage_offset = offsets[name] + return storage_offset + + if current_offset is None: + assert key == "0" + current_offset = zip_file.get_record_offset(name) + local_header_offset = zip_file.get_record_header_offset(name) + storage_offset = current_offset + else: + storage_offset = zip_file.get_record_offset_no_read( + current_offset, name, numel, storage_alignment + ) + local_header_offset = current_offset + + # This is only actually needed for storages that have typed_storage._data_ptr() == 0 + # after being read. Otherwise persistent_load would never "re-call" load_tensor + # for a given key. + offsets[name] = storage_offset + + # Increment current_offset to offset where next zipfile header starts + current_offset = storage_offset + numel + # add size of data descriptor after payload + if numel > 0: + if local_header_offset >= mz_uint32_max or numel >= mz_uint32_max: + current_offset += data_descripter_size64 + else: + current_offset += data_descripter_size32 + + return storage_offset + + def load_tensor(dtype, numel, key, location): + name = f"data/{key}" + if torch._guards.detect_fake_mode(None) is not None: + nbytes = numel * torch._utils._element_size(dtype) + storage = torch.UntypedStorage(nbytes, device="meta") + if can_calculate_storage_offsets: + storage._checkpoint_offset = _get_offset(key, name, numel) + else: + storage._checkpoint_offset = zip_file.get_record_offset(name) + elif _serialization_tls.skip_data: + nbytes = numel * torch._utils._element_size(dtype) + storage = torch.UntypedStorage(nbytes) + elif overall_storage is not None: + if can_calculate_storage_offsets and calculate_storage_offsets: + storage_offset = _get_offset(key, name, numel) + if run_debug_asserts: + if storage_offset != zip_file.get_record_offset(name): + raise RuntimeError( + "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment " + f"variable was set: Incorrect offset for {name}, got {storage_offset} expected " + f"{zip_file.get_record_offset(name)}" + ) + else: + storage_offset = zip_file.get_record_offset(name) + storage = overall_storage[storage_offset : storage_offset + numel] + else: + if can_calculate_storage_offsets and run_debug_asserts: + # This is debug code that we use to test the validity of + # torch.utils.serialization.config.load.calculate_storage_offsets throughout CI + storage_offset = _get_offset(key, name, numel) + if storage_offset != zip_file.get_record_offset(name): + raise RuntimeError( + "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment " + f"variable was set: Incorrect offset for {name}, got {storage_offset} expected " + f"{zip_file.get_record_offset(name)}" + ) + storage = ( + zip_file.get_storage_from_record(name, numel, torch.UntypedStorage) + ._typed_storage() + ._untyped_storage + ) + # swap here if byteswapping is needed + if byteorderdata is not None: + if byteorderdata.decode() != sys.byteorder: + storage.byteswap(dtype) + + # TODO: Once we decide to break serialization 
FC, we can + # stop wrapping with TypedStorage + + if torch._guards.detect_fake_mode(None) is None: + wrap_storage = restore_location(storage, location) + else: + storage._fake_device = location + wrap_storage = storage + + typed_storage = torch.storage.TypedStorage( + wrap_storage=wrap_storage, + dtype=dtype, + _internal=True, + ) + + if typed_storage._data_ptr() != 0: + loaded_storages[key] = typed_storage + + return typed_storage + + def persistent_load(saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + + assert typename == "storage", ( + f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" + ) + storage_type, key, location, numel = data + if storage_type is torch.UntypedStorage: + dtype = torch.uint8 + else: + dtype = storage_type.dtype + + if key in loaded_storages: + typed_storage = loaded_storages[key] + else: + nbytes = numel * torch._utils._element_size(dtype) + typed_storage = load_tensor( + dtype, nbytes, key, _maybe_decode_ascii(location) + ) + + return typed_storage + + load_module_mapping: dict[str, str] = { + # See https://github.com/pytorch/pytorch/pull/51633 + "torch.tensor": "torch._tensor" + } + + # Need to subclass Unpickler instead of directly monkey-patching the find_class method + # because it's marked readonly in pickle. + # The type: ignore is because mypy can't statically determine the type of this class. + class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] + # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732 + # Lets us override the imports that pickle uses when unpickling an object. + # This is useful for maintaining BC if we change a module path that tensor instantiation relies on. 
+ def find_class(self, mod_name, name): + if type(name) is str and "Storage" in name: + try: + return StorageType(name) + except KeyError: + pass + mod_name = load_module_mapping.get(mod_name, mod_name) + return super().find_class(mod_name, name) + + # Load the data (which may in turn use `persistent_load` to load tensors) + data_file = io.BytesIO(zip_file.get_record(pickle_file)) + + unpickler = UnpicklerWrapper(data_file, **pickle_load_args) + unpickler.persistent_load = persistent_load + # Needed for tensors where storage device and rebuild tensor device are + # not connected (wrapper subclasses and tensors rebuilt using numpy) + global _serialization_tls + _serialization_tls.map_location = map_location + result = unpickler.load() + _serialization_tls.map_location = None + + torch._utils._validate_loaded_sparse_tensors() + torch._C._log_api_usage_metadata( + "torch.load.metadata", {"serialization_id": zip_file.serialization_id()} + ) + return result + + +def _is_torchscript_zip(zip_file): + return "constants.pkl" in zip_file.get_all_records() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/typer/_completion_classes.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/typer/_completion_classes.py new file mode 100644 index 0000000000000000000000000000000000000000..070bbaa21410f85e526c9e15a5a677dec7f874d6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/typer/_completion_classes.py @@ -0,0 +1,203 @@ +import importlib.util +import os +import re +import sys +from typing import Any, Dict, List, Tuple + +import click +import click.parser +import click.shell_completion + +from ._completion_shared import ( + COMPLETION_SCRIPT_BASH, + COMPLETION_SCRIPT_FISH, + COMPLETION_SCRIPT_POWER_SHELL, + COMPLETION_SCRIPT_ZSH, + Shells, +) + +try: + import shellingham +except ImportError: # pragma: no cover + shellingham = None + + +def _sanitize_help_text(text: str) -> str: + """Sanitizes the help text by removing rich tags""" + if not importlib.util.find_spec("rich"): + return text + from . 
import rich_utils + + return rich_utils.rich_render_text(text) + + +class BashComplete(click.shell_completion.BashComplete): + name = Shells.bash.value + source_template = COMPLETION_SCRIPT_BASH + + def source_vars(self) -> Dict[str, Any]: + return { + "complete_func": self.func_name, + "autocomplete_var": self.complete_var, + "prog_name": self.prog_name, + } + + def get_completion_args(self) -> Tuple[List[str], str]: + cwords = click.parser.split_arg_string(os.environ["COMP_WORDS"]) + cword = int(os.environ["COMP_CWORD"]) + args = cwords[1:cword] + + try: + incomplete = cwords[cword] + except IndexError: + incomplete = "" + + return args, incomplete + + def format_completion(self, item: click.shell_completion.CompletionItem) -> str: + # TODO: Explore replicating the new behavior from Click, with item types and + # triggering completion for files and directories + # return f"{item.type},{item.value}" + return f"{item.value}" + + def complete(self) -> str: + args, incomplete = self.get_completion_args() + completions = self.get_completions(args, incomplete) + out = [self.format_completion(item) for item in completions] + return "\n".join(out) + + +class ZshComplete(click.shell_completion.ZshComplete): + name = Shells.zsh.value + source_template = COMPLETION_SCRIPT_ZSH + + def source_vars(self) -> Dict[str, Any]: + return { + "complete_func": self.func_name, + "autocomplete_var": self.complete_var, + "prog_name": self.prog_name, + } + + def get_completion_args(self) -> Tuple[List[str], str]: + completion_args = os.getenv("_TYPER_COMPLETE_ARGS", "") + cwords = click.parser.split_arg_string(completion_args) + args = cwords[1:] + if args and not completion_args.endswith(" "): + incomplete = args[-1] + args = args[:-1] + else: + incomplete = "" + return args, incomplete + + def format_completion(self, item: click.shell_completion.CompletionItem) -> str: + def escape(s: str) -> str: + return ( + s.replace('"', '""') + .replace("'", "''") + .replace("$", "\\$") + .replace("`", "\\`") + .replace(":", r"\\:") + ) + + # TODO: Explore replicating the new behavior from Click, pay attention to + # the difference with and without escape + # return f"{item.type}\n{item.value}\n{item.help if item.help else '_'}" + if item.help: + return f'"{escape(item.value)}":"{_sanitize_help_text(escape(item.help))}"' + else: + return f'"{escape(item.value)}"' + + def complete(self) -> str: + args, incomplete = self.get_completion_args() + completions = self.get_completions(args, incomplete) + res = [self.format_completion(item) for item in completions] + if res: + args_str = "\n".join(res) + return f"_arguments '*: :(({args_str}))'" + else: + return "_files" + + +class FishComplete(click.shell_completion.FishComplete): + name = Shells.fish.value + source_template = COMPLETION_SCRIPT_FISH + + def source_vars(self) -> Dict[str, Any]: + return { + "complete_func": self.func_name, + "autocomplete_var": self.complete_var, + "prog_name": self.prog_name, + } + + def get_completion_args(self) -> Tuple[List[str], str]: + completion_args = os.getenv("_TYPER_COMPLETE_ARGS", "") + cwords = click.parser.split_arg_string(completion_args) + args = cwords[1:] + if args and not completion_args.endswith(" "): + incomplete = args[-1] + args = args[:-1] + else: + incomplete = "" + return args, incomplete + + def format_completion(self, item: click.shell_completion.CompletionItem) -> str: + # TODO: Explore replicating the new behavior from Click, pay attention to + # the difference with and without formatted help + # if item.help: + # 
return f"{item.type},{item.value}\t{item.help}" + + # return f"{item.type},{item.value} + if item.help: + formatted_help = re.sub(r"\s", " ", item.help) + return f"{item.value}\t{_sanitize_help_text(formatted_help)}" + else: + return f"{item.value}" + + def complete(self) -> str: + complete_action = os.getenv("_TYPER_COMPLETE_FISH_ACTION", "") + args, incomplete = self.get_completion_args() + completions = self.get_completions(args, incomplete) + show_args = [self.format_completion(item) for item in completions] + if complete_action == "get-args": + if show_args: + return "\n".join(show_args) + elif complete_action == "is-args": + if show_args: + # Activate complete args (no files) + sys.exit(0) + else: + # Deactivate complete args (allow files) + sys.exit(1) + return "" # pragma: no cover + + +class PowerShellComplete(click.shell_completion.ShellComplete): + name = Shells.powershell.value + source_template = COMPLETION_SCRIPT_POWER_SHELL + + def source_vars(self) -> Dict[str, Any]: + return { + "complete_func": self.func_name, + "autocomplete_var": self.complete_var, + "prog_name": self.prog_name, + } + + def get_completion_args(self) -> Tuple[List[str], str]: + completion_args = os.getenv("_TYPER_COMPLETE_ARGS", "") + incomplete = os.getenv("_TYPER_COMPLETE_WORD_TO_COMPLETE", "") + cwords = click.parser.split_arg_string(completion_args) + args = cwords[1:-1] if incomplete else cwords[1:] + return args, incomplete + + def format_completion(self, item: click.shell_completion.CompletionItem) -> str: + return f"{item.value}:::{_sanitize_help_text(item.help) if item.help else ' '}" + + +def completion_init() -> None: + click.shell_completion.add_completion_class(BashComplete, Shells.bash.value) + click.shell_completion.add_completion_class(ZshComplete, Shells.zsh.value) + click.shell_completion.add_completion_class(FishComplete, Shells.fish.value) + click.shell_completion.add_completion_class( + PowerShellComplete, Shells.powershell.value + ) + click.shell_completion.add_completion_class(PowerShellComplete, Shells.pwsh.value) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/typer/_completion_shared.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/typer/_completion_shared.py new file mode 100644 index 0000000000000000000000000000000000000000..cc0add992c722c6044342d912f2772aedca86538 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/typer/_completion_shared.py @@ -0,0 +1,240 @@ +import os +import re +import subprocess +from enum import Enum +from pathlib import Path +from typing import Optional, Tuple + +import click + +try: + import shellingham +except ImportError: # pragma: no cover + shellingham = None + + +class Shells(str, Enum): + bash = "bash" + zsh = "zsh" + fish = "fish" + powershell = "powershell" + pwsh = "pwsh" + + +COMPLETION_SCRIPT_BASH = """ +%(complete_func)s() { + local IFS=$'\n' + COMPREPLY=( $( env COMP_WORDS="${COMP_WORDS[*]}" \\ + COMP_CWORD=$COMP_CWORD \\ + %(autocomplete_var)s=complete_bash $1 ) ) + return 0 +} + +complete -o default -F %(complete_func)s %(prog_name)s +""" + +COMPLETION_SCRIPT_ZSH = """ +#compdef %(prog_name)s + +%(complete_func)s() { + eval $(env _TYPER_COMPLETE_ARGS="${words[1,$CURRENT]}" %(autocomplete_var)s=complete_zsh %(prog_name)s) +} + +compdef %(complete_func)s %(prog_name)s +""" + +COMPLETION_SCRIPT_FISH = 'complete --command %(prog_name)s --no-files --arguments "(env %(autocomplete_var)s=complete_fish _TYPER_COMPLETE_FISH_ACTION=get-args _TYPER_COMPLETE_ARGS=(commandline -cp) 
%(prog_name)s)" --condition "env %(autocomplete_var)s=complete_fish _TYPER_COMPLETE_FISH_ACTION=is-args _TYPER_COMPLETE_ARGS=(commandline -cp) %(prog_name)s"' + +COMPLETION_SCRIPT_POWER_SHELL = """ +Import-Module PSReadLine +Set-PSReadLineKeyHandler -Chord Tab -Function MenuComplete +$scriptblock = { + param($wordToComplete, $commandAst, $cursorPosition) + $Env:%(autocomplete_var)s = "complete_powershell" + $Env:_TYPER_COMPLETE_ARGS = $commandAst.ToString() + $Env:_TYPER_COMPLETE_WORD_TO_COMPLETE = $wordToComplete + %(prog_name)s | ForEach-Object { + $commandArray = $_ -Split ":::" + $command = $commandArray[0] + $helpString = $commandArray[1] + [System.Management.Automation.CompletionResult]::new( + $command, $command, 'ParameterValue', $helpString) + } + $Env:%(autocomplete_var)s = "" + $Env:_TYPER_COMPLETE_ARGS = "" + $Env:_TYPER_COMPLETE_WORD_TO_COMPLETE = "" +} +Register-ArgumentCompleter -Native -CommandName %(prog_name)s -ScriptBlock $scriptblock +""" + +_completion_scripts = { + "bash": COMPLETION_SCRIPT_BASH, + "zsh": COMPLETION_SCRIPT_ZSH, + "fish": COMPLETION_SCRIPT_FISH, + "powershell": COMPLETION_SCRIPT_POWER_SHELL, + "pwsh": COMPLETION_SCRIPT_POWER_SHELL, +} + +# TODO: Probably refactor this, copied from Click 7.x +_invalid_ident_char_re = re.compile(r"[^a-zA-Z0-9_]") + + +def get_completion_script(*, prog_name: str, complete_var: str, shell: str) -> str: + cf_name = _invalid_ident_char_re.sub("", prog_name.replace("-", "_")) + script = _completion_scripts.get(shell) + if script is None: + click.echo(f"Shell {shell} not supported.", err=True) + raise click.exceptions.Exit(1) + return ( + script + % { + "complete_func": f"_{cf_name}_completion", + "prog_name": prog_name, + "autocomplete_var": complete_var, + } + ).strip() + + +def install_bash(*, prog_name: str, complete_var: str, shell: str) -> Path: + # Ref: https://github.com/scop/bash-completion#faq + # It seems bash-completion is the official completion system for bash: + # Ref: https://www.gnu.org/software/bash/manual/html_node/A-Programmable-Completion-Example.html + # But installing in the locations from the docs doesn't seem to have effect + completion_path = Path.home() / ".bash_completions" / f"{prog_name}.sh" + rc_path = Path.home() / ".bashrc" + rc_path.parent.mkdir(parents=True, exist_ok=True) + rc_content = "" + if rc_path.is_file(): + rc_content = rc_path.read_text() + completion_init_lines = [f"source '{completion_path}'"] + for line in completion_init_lines: + if line not in rc_content: # pragma: no cover + rc_content += f"\n{line}" + rc_content += "\n" + rc_path.write_text(rc_content) + # Install completion + completion_path.parent.mkdir(parents=True, exist_ok=True) + script_content = get_completion_script( + prog_name=prog_name, complete_var=complete_var, shell=shell + ) + completion_path.write_text(script_content) + return completion_path + + +def install_zsh(*, prog_name: str, complete_var: str, shell: str) -> Path: + # Setup Zsh and load ~/.zfunc + zshrc_path = Path.home() / ".zshrc" + zshrc_path.parent.mkdir(parents=True, exist_ok=True) + zshrc_content = "" + if zshrc_path.is_file(): + zshrc_content = zshrc_path.read_text() + completion_line = "fpath+=~/.zfunc; autoload -Uz compinit; compinit" + if completion_line not in zshrc_content: + zshrc_content += f"\n{completion_line}\n" + style_line = "zstyle ':completion:*' menu select" + # TODO: consider setting the style only for the current program + # style_line = f"zstyle ':completion:*:*:{prog_name}:*' menu select" + # Install zstyle completion config 
only if the user doesn't have a customization + if "zstyle" not in zshrc_content: + zshrc_content += f"\n{style_line}\n" + zshrc_content = f"{zshrc_content.strip()}\n" + zshrc_path.write_text(zshrc_content) + # Install completion under ~/.zfunc/ + path_obj = Path.home() / f".zfunc/_{prog_name}" + path_obj.parent.mkdir(parents=True, exist_ok=True) + script_content = get_completion_script( + prog_name=prog_name, complete_var=complete_var, shell=shell + ) + path_obj.write_text(script_content) + return path_obj + + +def install_fish(*, prog_name: str, complete_var: str, shell: str) -> Path: + path_obj = Path.home() / f".config/fish/completions/{prog_name}.fish" + parent_dir: Path = path_obj.parent + parent_dir.mkdir(parents=True, exist_ok=True) + script_content = get_completion_script( + prog_name=prog_name, complete_var=complete_var, shell=shell + ) + path_obj.write_text(f"{script_content}\n") + return path_obj + + +def install_powershell(*, prog_name: str, complete_var: str, shell: str) -> Path: + subprocess.run( + [ + shell, + "-Command", + "Set-ExecutionPolicy", + "Unrestricted", + "-Scope", + "CurrentUser", + ] + ) + result = subprocess.run( + [shell, "-NoProfile", "-Command", "echo", "$profile"], + check=True, + stdout=subprocess.PIPE, + ) + if result.returncode != 0: # pragma: no cover + click.echo("Couldn't get PowerShell user profile", err=True) + raise click.exceptions.Exit(result.returncode) + path_str = "" + if isinstance(result.stdout, str): # pragma: no cover + path_str = result.stdout + if isinstance(result.stdout, bytes): + for encoding in ["windows-1252", "utf8", "cp850"]: + try: + path_str = result.stdout.decode(encoding) + break + except UnicodeDecodeError: # pragma: no cover + pass + if not path_str: # pragma: no cover + click.echo("Couldn't decode the path automatically", err=True) + raise click.exceptions.Exit(1) + path_obj = Path(path_str.strip()) + parent_dir: Path = path_obj.parent + parent_dir.mkdir(parents=True, exist_ok=True) + script_content = get_completion_script( + prog_name=prog_name, complete_var=complete_var, shell=shell + ) + with path_obj.open(mode="a") as f: + f.write(f"{script_content}\n") + return path_obj + + +def install( + shell: Optional[str] = None, + prog_name: Optional[str] = None, + complete_var: Optional[str] = None, +) -> Tuple[str, Path]: + prog_name = prog_name or click.get_current_context().find_root().info_name + assert prog_name + if complete_var is None: + complete_var = "_{}_COMPLETE".format(prog_name.replace("-", "_").upper()) + test_disable_detection = os.getenv("_TYPER_COMPLETE_TEST_DISABLE_SHELL_DETECTION") + if shell is None and shellingham is not None and not test_disable_detection: + shell, _ = shellingham.detect_shell() + if shell == "bash": + installed_path = install_bash( + prog_name=prog_name, complete_var=complete_var, shell=shell + ) + return shell, installed_path + elif shell == "zsh": + installed_path = install_zsh( + prog_name=prog_name, complete_var=complete_var, shell=shell + ) + return shell, installed_path + elif shell == "fish": + installed_path = install_fish( + prog_name=prog_name, complete_var=complete_var, shell=shell + ) + return shell, installed_path + elif shell in {"powershell", "pwsh"}: + installed_path = install_powershell( + prog_name=prog_name, complete_var=complete_var, shell=shell + ) + return shell, installed_path + else: + click.echo(f"Shell {shell} is not supported.") + raise click.exceptions.Exit(1)
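
The `_save` path above writes a checkpoint as a zip container: the pickled object graph goes into a `data.pkl` record, each storage into a `data/<key>` record, alongside `.format_version`, `.storage_alignment`, and (unless disabled) a `byteorder` marker. The sketch below is illustrative only and not part of the vendored files; the `.pt` path is hypothetical, and the exact prefix of the record names and the presence of the byteorder/alignment records depend on the PyTorch version.

    import zipfile
    import torch

    torch.save({"w": torch.randn(4)}, "example.pt")  # hypothetical path
    with zipfile.ZipFile("example.pt") as zf:
        # Expect entries along the lines of <prefix>/data.pkl, <prefix>/data/0,
        # a version record, and possibly <prefix>/byteorder.
        print(zf.namelist())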
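
On the loading side, `load` picks a default for `weights_only` when the caller leaves it unset and lets the `TORCH_FORCE_WEIGHTS_ONLY_LOAD` / `TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD` environment variables override that default. A usage sketch mirroring the examples in the `load` docstring (the checkpoint path is hypothetical):

    import torch

    # weights_only=True routes unpickling through the restricted weights
    # unpickler; map_location remaps every storage at load time (here to CPU).
    state = torch.load("checkpoint.pt", map_location="cpu", weights_only=True)

    # map_location also accepts a dict, a torch.device, or a callable, matching
    # the branches in _get_restore_location above, e.g. remapping GPU 1 to GPU 0:
    state = torch.load(
        "checkpoint.pt", map_location={"cuda:1": "cuda:0"}, weights_only=True
    )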
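
When `mmap=True`, `load` requires an actual file path saved in the zipfile serialization format, maps the whole archive with `torch.UntypedStorage.from_file`, and hands that overall storage to `_load` so individual storages can be sliced out lazily. A minimal sketch under those assumptions (path again hypothetical):

    import torch

    # Raises a ValueError if `f` is a buffer rather than a path, and a
    # RuntimeError if the file was not saved with the zipfile serialization.
    state = torch.load("checkpoint.pt", mmap=True, weights_only=True)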
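
On the Typer side, `install()` in `_completion_shared.py` detects the shell (via `shellingham` when available), renders the matching completion script, and writes it to that shell's config location. A sketch calling it directly (the program name is hypothetical; Typer applications normally expose this through their generated `--install-completion` option rather than by importing this private module):

    from typer._completion_shared import install

    # Returns the resolved shell name and the path the completion script was
    # written to, e.g. ~/.bash_completions/my-cli.sh for bash.
    shell, script_path = install(shell="bash", prog_name="my-cli")
    print(shell, script_path)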