dtypes

package
v0.6.4
Published: Apr 13, 2025 License: Apache-2.0 Imports: 12 Imported by: 42

Documentation

Constants

const (
	// INVALID (or PJRT_Buffer_Type_INVALID) is the C enum name for InvalidDType.
	INVALID = InvalidDType

	// PRED (or PJRT_Buffer_Type_PRED) is the C enum name for Bool.
	PRED = Bool

	// S8 (or PJRT_Buffer_Type_S8) is the C enum name for Int8.
	S8 = Int8

	// S16 (or PJRT_Buffer_Type_S16) is the C enum name for Int16.
	S16 = Int16

	// S32 (or PJRT_Buffer_Type_S32) is the C enum name for Int32.
	S32 = Int32

	// S64 (or PJRT_Buffer_Type_S64) is the C enum name for Int64.
	S64 = Int64

	// U8 (or PJRT_Buffer_Type_U8) is the C enum name for Uint8.
	U8 = Uint8

	// U16 (or PJRT_Buffer_Type_U16) is the C enum name for Uint16.
	U16 = Uint16

	// U32 (or PJRT_Buffer_Type_U32) is the C enum name for Uint32.
	U32 = Uint32

	// U64 (or PJRT_Buffer_Type_U64) is the C enum name for Uint64.
	U64 = Uint64

	// F16 (or PJRT_Buffer_Type_F16) is the C enum name for Float16.
	F16 = Float16

	// F32 (or PJRT_Buffer_Type_F32) is the C enum name for Float32.
	F32 = Float32

	// F64 (or PJRT_Buffer_Type_F64) is the C enum name for Float64.
	F64 = Float64

	// BF16 (or PJRT_Buffer_Type_BF16) is the C enum name for BFloat16.
	BF16 = BFloat16

	// C64 (or PJRT_Buffer_Type_C64) is the C enum name for Complex64.
	C64 = Complex64

	// C128 (or PJRT_Buffer_Type_C128) is the C enum name for Complex128.
	C128 = Complex128
)

Aliases from the PJRT C API.

Variables

var MapOfNames = map[string]DType{
	"InvalidDType":  InvalidDType,
	"INVALID":       InvalidDType,
	"Bool":          Bool,
	"PRED":          Bool,
	"Int8":          Int8,
	"S8":            Int8,
	"Int16":         Int16,
	"S16":           Int16,
	"Int32":         Int32,
	"S32":           Int32,
	"Int64":         Int64,
	"S64":           Int64,
	"Uint8":         Uint8,
	"U8":            Uint8,
	"Uint16":        Uint16,
	"U16":           Uint16,
	"Uint32":        Uint32,
	"U32":           Uint32,
	"Uint64":        Uint64,
	"U64":           Uint64,
	"Float16":       Float16,
	"F16":           Float16,
	"Float32":       Float32,
	"F32":           Float32,
	"Float64":       Float64,
	"F64":           Float64,
	"BFloat16":      BFloat16,
	"BF16":          BFloat16,
	"Complex64":     Complex64,
	"C64":           Complex64,
	"Complex128":    Complex128,
	"C128":          Complex128,
	"F8E5M2":        F8E5M2,
	"F8E4M3FN":      F8E4M3FN,
	"F8E4M3B11FNUZ": F8E4M3B11FNUZ,
	"F8E5M2FNUZ":    F8E5M2FNUZ,
	"F8E4M3FNUZ":    F8E4M3FNUZ,
	"S4":            S4,
	"U4":            U4,
	"TOKEN":         TOKEN,
	"S2":            S2,
	"U2":            U2,
	"F8E4M3":        F8E4M3,
	"F8E3M4":        F8E3M4,
	"F8E8M0FNU":     F8E8M0FNU,
	"F4E2M1FN":      F4E2M1FN,
}

MapOfNames maps names to their dtypes. It also includes aliases for the various dtypes, and it is later initialized to include the lower-case version of each name.
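The lower-case initialization described above can be sketched as follows. This is a stand-alone illustration, not the package's code: `DType`, `mapOfNames` and `addLowerCaseNames` are local, hypothetical stand-ins so the example compiles without importing gopjrt, and the map holds only a few of the real entries.

```go
package main

import (
	"fmt"
	"strings"
)

// DType is a local stand-in for dtypes.DType so the sketch compiles on its own.
type DType int32

const (
	InvalidDType DType = 0
	Bool         DType = 1
	Float32      DType = 11
)

// mapOfNames is a shortened, hypothetical copy of MapOfNames.
var mapOfNames = map[string]DType{
	"InvalidDType": InvalidDType,
	"INVALID":      InvalidDType,
	"Bool":         Bool,
	"PRED":         Bool,
	"Float32":      Float32,
	"F32":          Float32,
}

// addLowerCaseNames adds a lower-case key for every key already in m.
// The new keys are collected first so the map is not grown while ranging over it.
func addLowerCaseNames(m map[string]DType) {
	lowered := make(map[string]DType, len(m))
	for name, dtype := range m {
		lowered[strings.ToLower(name)] = dtype
	}
	for name, dtype := range lowered {
		m[name] = dtype
	}
}

func main() {
	addLowerCaseNames(mapOfNames)
	// The Go-style name, the C alias and their lower-case forms all resolve to the same dtype.
	fmt.Println(mapOfNames["F32"] == mapOfNames["float32"])
}
```

With the real package, the equivalent lookup would be a plain map access on `MapOfNames`, checking the second return value to detect unknown names.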

Functions

func DTypeStrings added in v0.2.1

func DTypeStrings() []string

DTypeStrings returns a slice of all String values of the enum.

Types

type DType

type DType int32

DType is an enum that represents the data type of a buffer or a scalar. These are all the types supported by XLA/PJRT.

The names come from the C/C++ constants, so they are not Go idiomatic. The package provides some aliases.

Unfortunately, the data type enums used in XLA/PJRT (which DType is modeled after) and in the C++ XlaBuilder (and other parts of XLA) don't match. The gopjrt project uses the PJRT enum everywhere and converts when needed to call C++ code (see DType.PrimitiveType and FromPrimitiveType for the conversion).

const (
	// InvalidDType is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_INVALID).
	// Invalid primitive type to serve as default.
	InvalidDType DType = 0

	// Bool is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_PRED).
	// Predicates are two-state booleans.
	Bool DType = 1

	// Int8 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_S8).
	// Signed integral values of fixed width.
	Int8 DType = 2

	// Int16 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_S16).
	Int16 DType = 3

	// Int32 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_S32).
	Int32 DType = 4

	// Int64 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_S64).
	Int64 DType = 5

	// Uint8 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_U8).
	// Unsigned integral values of fixed width.
	Uint8 DType = 6

	// Uint16 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_U16).
	Uint16 DType = 7

	// Uint32 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_U32).
	Uint32 DType = 8

	// Uint64 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_U64).
	Uint64 DType = 9

	// Float16 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_F16).
	// Floating-point values of fixed width.
	Float16 DType = 10

	// Float32 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_F32).
	Float32 DType = 11

	// Float64 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_F64).
	Float64 DType = 12

	// BFloat16 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_BF16).
	// Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
	// floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
	// and 7 bits for the mantissa.
	BFloat16 DType = 13

	// Complex64 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_C64).
	// Complex values of fixed width.
	//
	// Paired F32 (real, imag), as in std::complex<float>.
	Complex64 DType = 14

	// Complex128 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_C128).
	// Paired F64 (real, imag), as in std::complex<double>.
	Complex128 DType = 15

	// F8E5M2 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_F8E5M2).
	// Truncated 8 bit floating-point formats.
	F8E5M2 DType = 16

	// F8E4M3FN is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_F8E4M3FN).
	F8E4M3FN DType = 17

	// F8E4M3B11FNUZ is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_F8E4M3B11FNUZ).
	F8E4M3B11FNUZ DType = 18

	// F8E5M2FNUZ is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_F8E5M2FNUZ).
	F8E5M2FNUZ DType = 19

	// F8E4M3FNUZ is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_F8E4M3FNUZ).
	F8E4M3FNUZ DType = 20

	// S4 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_S4).
	// 4-bit integer types
	S4 DType = 21

	// U4 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_U4).
	U4 DType = 22

	// TOKEN is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_TOKEN).
	TOKEN DType = 23

	// S2 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_S2).
	// 2-bit integer types
	S2 DType = 24

	// U2 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_U2).
	U2 DType = 25

	// F8E4M3 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_F8E4M3).
	// More truncated 8 bit floating-point formats.
	F8E4M3 DType = 26

	// F8E3M4 is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_F8E3M4).
	F8E3M4 DType = 27

	// F8E8M0FNU is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_F8E8M0FNU).
	F8E8M0FNU DType = 28

	// F4E2M1FN is a 1:1 mapping of the corresponding C enum value defined in pjrt_c_api.h (as PJRT_Buffer_Type_F4E2M1FN).
	// 4-bit MX floating-point format.
	F4E2M1FN DType = 29
)
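The BFloat16 comment above describes the format as 1 sign bit, 8 exponent bits and 7 mantissa bits, which is the upper half of an IEEE 754 float32. A minimal sketch of that layout follows; the function names are hypothetical, the conversion truncates rather than rounds, and it is not the implementation used by the `bfloat16` subpackage.

```go
package main

import (
	"fmt"
	"math"
)

// toBFloat16Bits keeps the upper 16 bits of the float32 encoding:
// 1 sign bit, 8 exponent bits and the top 7 mantissa bits.
// It truncates instead of rounding and gives NaN no special treatment.
func toBFloat16Bits(f float32) uint16 {
	return uint16(math.Float32bits(f) >> 16)
}

// fromBFloat16Bits widens back to float32 by zero-filling the dropped mantissa bits.
func fromBFloat16Bits(b uint16) float32 {
	return math.Float32frombits(uint32(b) << 16)
}

// splitBFloat16 returns the three bit fields of a bfloat16 value.
func splitBFloat16(b uint16) (sign, exponent, mantissa uint16) {
	return b >> 15, (b >> 7) & 0xFF, b & 0x7F
}

func main() {
	bits := toBFloat16Bits(-2.5)
	sign, exponent, mantissa := splitBFloat16(bits)
	fmt.Printf("bits=%#04x sign=%d exponent=%d mantissa=%d\n", bits, sign, exponent, mantissa)
	// -2.5 needs only two mantissa bits, so it survives the round trip exactly.
	fmt.Println(fromBFloat16Bits(bits))
}
```

Because the exponent field has the same width as in float32, the representable range is the same; only precision is lost.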

func DTypeString added in v0.2.1

func DTypeString(s string) (DType, error)

DTypeString retrieves an enum value from the string name of an enum constant. It returns an error if the name is not part of the enum.

func DTypeValues added in v0.2.1

func DTypeValues() []DType

DTypeValues returns all values of the enum.

func FromAny

func FromAny(value any) DType

FromAny introspects the underlying type of value and returns the corresponding DType. Non-scalar or unsupported types return InvalidDType.
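One way such introspection can work is to switch on the `reflect.Kind` of the value. The sketch below is a stand-alone illustration under that assumption: `DType`, `fromGoType` and `fromAny` are local, hypothetical names covering only a handful of types, and the package's own logic may differ (for example in how it treats named types).

```go
package main

import (
	"fmt"
	"reflect"
	"strconv"
)

// DType is a local stand-in for dtypes.DType so the sketch compiles on its own.
type DType int32

const (
	InvalidDType DType = 0
	Bool         DType = 1
	Int32        DType = 4
	Int64        DType = 5
	Float32      DType = 11
	Float64      DType = 12
	Complex64    DType = 14
)

// fromGoType maps a scalar Go type to a dtype; anything else yields InvalidDType.
func fromGoType(t reflect.Type) DType {
	switch t.Kind() {
	case reflect.Bool:
		return Bool
	case reflect.Int:
		// Go's int is 32 or 64 bits wide depending on the platform.
		if strconv.IntSize == 32 {
			return Int32
		}
		return Int64
	case reflect.Int32:
		return Int32
	case reflect.Int64:
		return Int64
	case reflect.Float32:
		return Float32
	case reflect.Float64:
		return Float64
	case reflect.Complex64:
		return Complex64
	default:
		return InvalidDType
	}
}

// fromAny inspects the dynamic type of value; a nil interface has no type to inspect.
func fromAny(value any) DType {
	if value == nil {
		return InvalidDType
	}
	return fromGoType(reflect.TypeOf(value))
}

func main() {
	fmt.Println(fromAny(float32(1.5)) == Float32)
	fmt.Println(fromAny([]float32{1.5}) == InvalidDType) // slices are not scalars
}
```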

func FromGenericsType added in v0.2.0

func FromGenericsType[T Supported]() DType

FromGenericsType returns the DType enum for the given type that this package knows about.
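A type-parameter-only function like this can be written with a type switch on the zero value of `T`. The following is a sketch of that technique with local, hypothetical names (`fromGenericsType`, `zerosFor`) and a shortened `Supported` constraint; it is not taken from the package source.

```go
package main

import "fmt"

// DType is a local stand-in for dtypes.DType so the sketch compiles on its own.
type DType int32

const (
	InvalidDType DType = 0
	Bool         DType = 1
	Int32        DType = 4
	Int64        DType = 5
	Float32      DType = 11
	Float64      DType = 12
)

// Supported is a shortened stand-in for the package's Supported constraint.
type Supported interface {
	bool | int32 | int64 | float32 | float64
}

// fromGenericsType resolves the dtype from the type parameter alone, with no value argument.
func fromGenericsType[T Supported]() DType {
	var zero T
	switch any(zero).(type) {
	case bool:
		return Bool
	case int32:
		return Int32
	case int64:
		return Int64
	case float32:
		return Float32
	case float64:
		return Float64
	}
	return InvalidDType
}

// zerosFor allocates a flat buffer of n elements and reports the matching dtype.
func zerosFor[T Supported](n int) ([]T, DType) {
	return make([]T, n), fromGenericsType[T]()
}

func main() {
	buffer, dtype := zerosFor[float32](4)
	fmt.Println(len(buffer), dtype == Float32)
}
```

This form is handy in generic helpers that allocate or describe buffers before any value of type `T` exists.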

func FromGoType

func FromGoType(t reflect.Type) DType

FromGoType returns the DType for the given reflect.Type. It panics for unknown DType values.

func FromPrimitiveType

func FromPrimitiveType(primitiveType xla_data.PrimitiveType) DType

FromPrimitiveType returns the equivalent DType. For internal use only.

Unfortunately, the data type enums used in PJRT (which DType is modeled after) and in the C++ XlaBuilder (and other parts of XLA) don't match.

func (DType) GoStr

func (dtype DType) GoStr() string

GoStr converts dtype to the corresponding Go type and converts that to a string. Note that the names differ from the DType names (so the `Int64` dtype is simply `int` in Go).

func (DType) GoType

func (dtype DType) GoType() reflect.Type

GoType returns the Go `reflect.Type` corresponding to the tensor DType.

func (DType) HighestValue

func (dtype DType) HighestValue() any

HighestValue returns the highest value for dtype, converted to the corresponding Go type. For float values it returns positive infinity. There is no highest value for complex numbers, since they are not ordered.

func (DType) IsADType added in v0.2.1

func (i DType) IsADType() bool

IsADType returns true if the value is listed in the enum definition, and false otherwise.

func (DType) IsComplex added in v0.2.0

func (dtype DType) IsComplex() bool

IsComplex returns whether dtype is a supported complex number type.

func (DType) IsFloat added in v0.2.0

func (dtype DType) IsFloat() bool

IsFloat returns whether dtype is a supported float -- float types not yet supported will return false. It returns false for complex numbers.

func (DType) IsFloat16 added in v0.2.0

func (dtype DType) IsFloat16() bool

IsFloat16 returns whether dtype is a supported float with 16 bits: Float16 or BFloat16.

func (DType) IsInt added in v0.2.0

func (dtype DType) IsInt() bool

IsInt returns whether dtype is a supported integer type -- integer types not yet supported will return false.

func (DType) IsSupported added in v0.2.0

func (dtype DType) IsSupported() bool

IsSupported returns whether dtype is supported by `gopjrt`.

func (DType) IsUnsigned added in v0.4.6

func (dtype DType) IsUnsigned() bool

IsUnsigned returns whether dtype is one of the unsigned (only int for now) types.

func (DType) LowestValue

func (dtype DType) LowestValue() any

LowestValue returns the lowest value for dtype, converted to the corresponding Go type. For float values it returns negative infinity. There is no lowest value for complex numbers, since they are not ordered.
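A common use of these bounds is to seed a running maximum or minimum. The sketch below illustrates the described behavior for a few dtypes using only the standard `math` package; `highestValue`, `lowestValue` and `maxOf` are hypothetical local helpers, and returning `nil` for unordered or unlisted dtypes is this sketch's choice, not documented package behavior.

```go
package main

import (
	"fmt"
	"math"
)

// DType is a local stand-in for dtypes.DType so the sketch compiles on its own.
type DType int32

const (
	Int8      DType = 2
	Int32     DType = 4
	Uint8     DType = 6
	Float32   DType = 11
	Float64   DType = 12
	Complex64 DType = 14
)

// highestValue returns the upper bound of d as the matching Go type.
func highestValue(d DType) any {
	switch d {
	case Int8:
		return int8(math.MaxInt8)
	case Int32:
		return int32(math.MaxInt32)
	case Uint8:
		return uint8(math.MaxUint8)
	case Float32:
		return float32(math.Inf(1))
	case Float64:
		return math.Inf(1)
	}
	return nil // complex and unlisted dtypes: no bound in this sketch
}

// lowestValue returns the lower bound of d as the matching Go type.
func lowestValue(d DType) any {
	switch d {
	case Int8:
		return int8(math.MinInt8)
	case Int32:
		return int32(math.MinInt32)
	case Uint8:
		return uint8(0)
	case Float32:
		return float32(math.Inf(-1))
	case Float64:
		return math.Inf(-1)
	}
	return nil
}

// maxOf seeds the running maximum with the lowest float32 value (negative infinity).
func maxOf(values []float32) float32 {
	best := lowestValue(Float32).(float32)
	for _, v := range values {
		if v > best {
			best = v
		}
	}
	return best
}

func main() {
	fmt.Println(maxOf([]float32{-3, -1, -2}))
	fmt.Println(highestValue(Int8), lowestValue(Int8))
}
```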

func (DType) MarshalJSON added in v0.2.1

func (i DType) MarshalJSON() ([]byte, error)

MarshalJSON implements the json.Marshaler interface for DType.

func (DType) MarshalText added in v0.2.1

func (i DType) MarshalText() ([]byte, error)

MarshalText implements the encoding.TextMarshaler interface for DType.

func (DType) MarshalYAML added in v0.2.1

func (i DType) MarshalYAML() (interface{}, error)

MarshalYAML implements a YAML Marshaler for DType.

func (DType) Memory added in v0.2.0

func (dtype DType) Memory() uintptr

Memory returns the number of bytes for the given DType. It's an alias to Size, converted to uintptr.

func (DType) PrimitiveType

func (dtype DType) PrimitiveType() xla_data.PrimitiveType

PrimitiveType returns the DType equivalent used in C++ XlaBuilder. For internal use only.

Unfortunately, the data type enums used in PJRT (which DType is modeled after) and in the C++ XlaBuilder (and other parts of XLA) don't match.

func (DType) RealDType added in v0.2.0

func (dtype DType) RealDType() DType

RealDType returns the real component of complex dtypes. For float dtypes, it returns itself.

It returns InvalidDType for other non-(complex or float) dtypes.
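The mapping just described is small enough to write out. The sketch below is a local illustration with a hypothetical `realDType` function; which dtypes the package counts as "float" beyond the four listed here is an assumption left open.

```go
package main

import "fmt"

// DType is a local stand-in for dtypes.DType so the sketch compiles on its own.
type DType int32

const (
	InvalidDType DType = 0
	Int32        DType = 4
	Float16      DType = 10
	Float32      DType = 11
	Float64      DType = 12
	BFloat16     DType = 13
	Complex64    DType = 14
	Complex128   DType = 15
)

// realDType returns the dtype of the real component for complex dtypes,
// the dtype itself for the float dtypes listed, and InvalidDType otherwise.
func realDType(d DType) DType {
	switch d {
	case Complex64:
		return Float32 // paired F32 (real, imag)
	case Complex128:
		return Float64 // paired F64 (real, imag)
	case Float16, Float32, Float64, BFloat16:
		return d
	}
	return InvalidDType
}

func main() {
	fmt.Println(realDType(Complex64) == Float32)
	fmt.Println(realDType(Int32) == InvalidDType)
}
```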

func (DType) Size

func (dtype DType) Size() int

Size returns the number of bytes for the given DType. If the size is less than 1 byte (as with 4-bit types), consider the SizeForDimensions method.

func (DType) SizeForDimensions added in v0.5.0

func (dtype DType) SizeForDimensions(dimensions ...int) int

SizeForDimensions returns the size in bytes used for the given dimensions. This is a safer method than Size in case the dtype uses an underlying size that is not a multiple of 8 bits.

It also works for scalars (one element), where the dimensions list is empty.
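To see why a per-element byte size breaks down for sub-byte dtypes, consider the arithmetic in bits. The sketch below assumes elements are packed tightly and the total is rounded up to whole bytes; that packing rule, and the helper names `bitsPerElement` and `sizeForDimensions`, are assumptions for illustration, so the package's actual result may differ.

```go
package main

import "fmt"

// DType is a local stand-in for dtypes.DType so the sketch compiles on its own.
type DType int32

const (
	Int8    DType = 2
	Float32 DType = 11
	S4      DType = 21
	S2      DType = 24
)

// bitsPerElement returns the width in bits for the few dtypes this sketch covers.
func bitsPerElement(d DType) int {
	switch d {
	case S2:
		return 2
	case S4:
		return 4
	case Int8:
		return 8
	case Float32:
		return 32
	}
	return 0
}

// sizeForDimensions multiplies the dimensions (an empty list means a scalar,
// i.e. one element), converts to bits and rounds up to whole bytes.
func sizeForDimensions(d DType, dimensions ...int) int {
	numElements := 1
	for _, dim := range dimensions {
		numElements *= dim
	}
	totalBits := numElements * bitsPerElement(d)
	return (totalBits + 7) / 8
}

func main() {
	fmt.Println(sizeForDimensions(Float32, 2, 3)) // 6 elements of 4 bytes each
	fmt.Println(sizeForDimensions(S4, 3))         // 12 bits, rounded up to whole bytes
	fmt.Println(sizeForDimensions(S4))            // a scalar still occupies a whole byte
}
```

A byte-per-element size for S4 would have to be either 0 or 1, which under- or over-counts any shape with more than one element.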

func (DType) SmallestNonZeroValueForDType

func (dtype DType) SmallestNonZeroValueForDType() any

SmallestNonZeroValueForDType returns the smallest non-zero value for the dtype. It is only useful for float types. The return value is converted to the corresponding Go type. There is no smallest non-zero value for complex numbers, since they are not ordered.

func (DType) String

func (i DType) String() string

func (*DType) UnmarshalJSON added in v0.2.1

func (i *DType) UnmarshalJSON(data []byte) error

UnmarshalJSON implements the json.Unmarshaler interface for DType.
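Implementing both JSON interfaces lets a dtype field round-trip through `encoding/json`. The sketch below shows the general pattern on a local stand-in type and assumes the JSON form is the enum's String name; the struct `tensorSpec` and the helpers `encodeSpec` and `decodeSpec` are hypothetical, and the exact encoding used by the package is not confirmed here.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// DType is a local stand-in for dtypes.DType so the sketch compiles on its own.
type DType int32

const (
	InvalidDType DType = 0
	Int32        DType = 4
	Float32      DType = 11
)

var names = map[DType]string{InvalidDType: "InvalidDType", Int32: "Int32", Float32: "Float32"}

func (d DType) String() string {
	if name, ok := names[d]; ok {
		return name
	}
	return fmt.Sprintf("DType(%d)", int32(d))
}

// MarshalJSON writes the dtype as its name (assumed encoding for this sketch).
func (d DType) MarshalJSON() ([]byte, error) {
	return json.Marshal(d.String())
}

// UnmarshalJSON accepts a JSON string holding a known dtype name.
func (d *DType) UnmarshalJSON(data []byte) error {
	var name string
	if err := json.Unmarshal(data, &name); err != nil {
		return fmt.Errorf("DType should be a JSON string, got %s", data)
	}
	for value, candidate := range names {
		if candidate == name {
			*d = value
			return nil
		}
	}
	return fmt.Errorf("%q is not a known DType name", name)
}

// tensorSpec is a hypothetical struct that embeds a dtype in a JSON document.
type tensorSpec struct {
	DType DType `json:"dtype"`
	Dims  []int `json:"dims"`
}

func encodeSpec(spec tensorSpec) (string, error) {
	encoded, err := json.Marshal(spec)
	return string(encoded), err
}

func decodeSpec(text string) (tensorSpec, error) {
	var spec tensorSpec
	err := json.Unmarshal([]byte(text), &spec)
	return spec, err
}

func main() {
	text, err := encodeSpec(tensorSpec{DType: Float32, Dims: []int{2, 3}})
	if err != nil {
		panic(err)
	}
	fmt.Println(text)
	spec, err := decodeSpec(text)
	fmt.Println(spec.DType, err)
}
```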

func (*DType) UnmarshalText added in v0.2.1

func (i *DType) UnmarshalText(text []byte) error

UnmarshalText implements the encoding.TextUnmarshaler interface for DType.

func (*DType) UnmarshalYAML added in v0.2.1

func (i *DType) UnmarshalYAML(unmarshal func(interface{}) error) error

UnmarshalYAML implements a YAML Unmarshaler for DType.

func (DType) Values added in v0.2.1

func (DType) Values() []string

type GoFloat added in v0.2.0

type GoFloat interface {
	float32 | float64
}

GoFloat represents a continuous Go numeric type, supported by GoMLX. It doesn't include complex numbers.

type Number added in v0.2.0

type Number interface {
	float32 | float64 | int | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | complex64 | complex128
}

Number represents the Go numeric types that are supported by the graph package. Used as traits for generics.

It includes complex numbers. It doesn't include float16.Float16 or bfloat16.BFloat16 because they are not a native number type.

type NumberNotComplex added in v0.2.0

type NumberNotComplex interface {
	float32 | float64 | int | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
}

NumberNotComplex represents the Go numeric types that are supported by the graph package, except the complex numbers. Used as a generics constraint.

It doesn't include float16.Float16 (not a native number type). See also Number.

type Supported

type Supported interface {
	bool | float16.Float16 | bfloat16.BFloat16 |
		float32 | float64 | int | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 |
		complex64 | complex128
}

Supported lists the Go types that `gopjrt` knows how to convert -- there are more types that can be manually converted. Used as traits for generics.

Note that Go's `int` type is not portable, since it may translate to dtype Int32 or Int64 depending on the platform.

Directories

Path Synopsis
Package bfloat16 is a trivial implementation for the bfloat16 type, based on https://github.com/x448/float16 and the pending issue in https://github.com/x448/float16/issues/22