Skip to main content

toasty_core/stmt/
value_set.rs

1use super::{SparseRecord, Value, ValueRecord};
2
3use hashbrown::{Equivalent, HashSet};
4use std::hash::{Hash, Hasher};
5
6/// A set of [`Value`]s.
7///
8/// Provides hash-based deduplication with well-defined semantics for every
9/// `Value` variant, including future floating-point variants (which will use
10/// bitwise comparison so `NaN == NaN` and `+0.0 != -0.0`). `Value` itself does
11/// not implement `Hash`/`Eq` because the right float policy is
12/// context-dependent; `ValueSet` picks the policy suitable for deduplication.
13#[derive(Debug, Default, Clone)]
14pub struct ValueSet {
15    inner: HashSet<HashableValue>,
16}
17
18impl ValueSet {
19    /// Creates an empty set.
20    pub fn new() -> Self {
21        Self {
22            inner: HashSet::new(),
23        }
24    }
25
26    /// Creates an empty set with capacity for at least `capacity` values.
27    pub fn with_capacity(capacity: usize) -> Self {
28        Self {
29            inner: HashSet::with_capacity(capacity),
30        }
31    }
32
33    /// Inserts a value into the set. Returns `true` if the value was not
34    /// already present.
35    pub fn insert(&mut self, value: Value) -> bool {
36        self.inner.insert(HashableValue(value))
37    }
38
39    /// Returns the number of values in the set.
40    pub fn len(&self) -> usize {
41        self.inner.len()
42    }
43
44    /// Returns `true` if the set contains no values.
45    pub fn is_empty(&self) -> bool {
46        self.inner.is_empty()
47    }
48}
49
50/// A `Value` wrapped so it can be used as a hash-table key.
51///
52/// Hash/equality semantics are "literal bits": recursive, with bitwise
53/// comparison for floats once they are added (so `NaN` hashes equal to itself
54/// and `+0.0` is distinct from `-0.0`). This is appropriate for deduplication
55/// and join-index lookup, not for SQL-semantic equality.
56#[derive(Clone, Debug)]
57pub(super) struct HashableValue(pub(super) Value);
58
59impl PartialEq for HashableValue {
60    fn eq(&self, other: &Self) -> bool {
61        value_eq(&self.0, &other.0)
62    }
63}
64
65impl Eq for HashableValue {}
66
67impl Hash for HashableValue {
68    fn hash<H: Hasher>(&self, state: &mut H) {
69        hash_value(&self.0, state);
70    }
71}
72
73/// Borrowed view over a `&[Value]`, used to query hash tables keyed by
74/// `Vec<HashableValue>` without per-lookup allocation. Implements
75/// [`Equivalent`] against the owned key type.
76///
77/// The `Hash` impl here must produce a byte stream identical to
78/// `<Vec<HashableValue> as Hash>` for equivalent contents.
79pub(super) struct HashableValueSlice<'a>(pub(super) &'a [Value]);
80
81impl Hash for HashableValueSlice<'_> {
82    fn hash<H: Hasher>(&self, state: &mut H) {
83        // Mirror `<[HashableValue] as Hash>::hash`: length prefix, then each
84        // element. `usize::hash` and the default `write_length_prefix` both
85        // resolve to `Hasher::write_usize`, so the two paths agree.
86        self.0.len().hash(state);
87        for v in self.0 {
88            hash_value(v, state);
89        }
90    }
91}
92
93impl Equivalent<Vec<HashableValue>> for HashableValueSlice<'_> {
94    fn equivalent(&self, key: &Vec<HashableValue>) -> bool {
95        self.0.len() == key.len() && self.0.iter().zip(key).all(|(v, hv)| value_eq(v, &hv.0))
96    }
97}
98
99// Each variant is spelled out rather than collapsed into a blanket `a == b`
100// fallback for two reasons:
101//
102// 1. Some variants must diverge from `PartialEq`: containers (List, Record,
103//    SparseRecord) recurse through `value_eq` so that future float variants
104//    inside them use bitwise comparison instead of `PartialEq`'s NaN-never-
105//    equal semantics. Future F32/F64 variants will themselves diverge, using
106//    `to_bits()` equality.
107// 2. Listing every variant forces a deliberate choice when a new one is
108//    added. A tuple match can't be compiler-exhaustive without a catch-all,
109//    so the real exhaustiveness check lives in `hash_value` below (single-
110//    value match, no `_` arm) — adding a variant there errors until it's
111//    handled, at which point the author is prompted to audit this function.
112pub(super) fn value_eq(a: &Value, b: &Value) -> bool {
113    use Value::*;
114    match (a, b) {
115        (Null, Null) => true,
116        (Bool(a), Bool(b)) => a == b,
117        (I8(a), I8(b)) => a == b,
118        (I16(a), I16(b)) => a == b,
119        (I32(a), I32(b)) => a == b,
120        (I64(a), I64(b)) => a == b,
121        (U8(a), U8(b)) => a == b,
122        (U16(a), U16(b)) => a == b,
123        (U32(a), U32(b)) => a == b,
124        (U64(a), U64(b)) => a == b,
125        (F32(a), F32(b)) => a.to_bits() == b.to_bits(),
126        (F64(a), F64(b)) => a.to_bits() == b.to_bits(),
127        (String(a), String(b)) => a == b,
128        (Bytes(a), Bytes(b)) => a == b,
129        (Uuid(a), Uuid(b)) => a == b,
130        (List(a), List(b)) => a.len() == b.len() && a.iter().zip(b).all(|(x, y)| value_eq(x, y)),
131        (Record(a), Record(b)) => record_eq(a, b),
132        (SparseRecord(a), SparseRecord(b)) => sparse_record_eq(a, b),
133        #[cfg(feature = "rust_decimal")]
134        (Decimal(a), Decimal(b)) => a == b,
135        #[cfg(feature = "bigdecimal")]
136        (BigDecimal(a), BigDecimal(b)) => a == b,
137        #[cfg(feature = "jiff")]
138        (Timestamp(a), Timestamp(b)) => a == b,
139        #[cfg(feature = "jiff")]
140        (Zoned(a), Zoned(b)) => a == b,
141        #[cfg(feature = "jiff")]
142        (Date(a), Date(b)) => a == b,
143        #[cfg(feature = "jiff")]
144        (Time(a), Time(b)) => a == b,
145        #[cfg(feature = "jiff")]
146        (DateTime(a), DateTime(b)) => a == b,
147        _ => false,
148    }
149}
150
151fn record_eq(a: &ValueRecord, b: &ValueRecord) -> bool {
152    a.fields.len() == b.fields.len() && a.fields.iter().zip(&b.fields).all(|(x, y)| value_eq(x, y))
153}
154
155fn sparse_record_eq(a: &SparseRecord, b: &SparseRecord) -> bool {
156    a.fields == b.fields
157        && a.values.len() == b.values.len()
158        && a.values.iter().zip(&b.values).all(|(x, y)| value_eq(x, y))
159}
160
161pub(super) fn hash_value<H: Hasher>(v: &Value, state: &mut H) {
162    // Hash the discriminant so that two variants with equal payload bits
163    // don't collide (e.g. `I32(0)` vs `U32(0)`).
164    std::mem::discriminant(v).hash(state);
165    match v {
166        Value::Null => {}
167        Value::Bool(x) => x.hash(state),
168        Value::I8(x) => x.hash(state),
169        Value::I16(x) => x.hash(state),
170        Value::I32(x) => x.hash(state),
171        Value::I64(x) => x.hash(state),
172        Value::U8(x) => x.hash(state),
173        Value::U16(x) => x.hash(state),
174        Value::U32(x) => x.hash(state),
175        Value::U64(x) => x.hash(state),
176        Value::F32(x) => x.to_bits().hash(state),
177        Value::F64(x) => x.to_bits().hash(state),
178        Value::String(x) => x.hash(state),
179        Value::Bytes(x) => x.hash(state),
180        Value::Uuid(x) => x.hash(state),
181        Value::List(items) => {
182            items.len().hash(state);
183            for it in items {
184                hash_value(it, state);
185            }
186        }
187        Value::Record(r) => {
188            r.fields.len().hash(state);
189            for v in &r.fields {
190                hash_value(v, state);
191            }
192        }
193        Value::SparseRecord(r) => {
194            r.fields.hash(state);
195            r.values.len().hash(state);
196            for v in &r.values {
197                hash_value(v, state);
198            }
199        }
200        #[cfg(feature = "rust_decimal")]
201        Value::Decimal(x) => x.hash(state),
202        #[cfg(feature = "bigdecimal")]
203        Value::BigDecimal(x) => {
204            // `bigdecimal::BigDecimal` implements `Hash`.
205            x.hash(state);
206        }
207        #[cfg(feature = "jiff")]
208        Value::Timestamp(x) => x.hash(state),
209        #[cfg(feature = "jiff")]
210        Value::Zoned(x) => x.hash(state),
211        #[cfg(feature = "jiff")]
212        Value::Date(x) => x.hash(state),
213        #[cfg(feature = "jiff")]
214        Value::Time(x) => x.hash(state),
215        #[cfg(feature = "jiff")]
216        Value::DateTime(x) => x.hash(state),
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::hash_map::DefaultHasher;

    /// Hashes `v` with a `DefaultHasher::new()` instance, so two calls are
    /// directly comparable.
    fn hash<T: Hash + ?Sized>(v: &T) -> u64 {
        let mut h = DefaultHasher::new();
        v.hash(&mut h);
        h.finish()
    }

    #[test]
    fn slice_and_vec_hash_match() {
        // HashableValueSlice must produce the same hash as Vec<HashableValue>
        // for equivalent contents. If this fails, `HashIndex` lookups will
        // miss entries even when equal.
        let values = [Value::from(1_i64), Value::from("hello"), Value::from(true)];
        let owned: Vec<HashableValue> = values.iter().cloned().map(HashableValue).collect();

        assert_eq!(hash(&HashableValueSlice(&values)), hash(&owned));
    }

    #[test]
    fn slice_and_vec_hash_match_empty() {
        let values: [Value; 0] = [];
        let owned: Vec<HashableValue> = vec![];
        assert_eq!(hash(&HashableValueSlice(&values)), hash(&owned));
    }

    #[test]
    fn slice_equivalent_to_vec() {
        let values = [Value::from(1_i64), Value::from(2_i64)];
        let owned: Vec<HashableValue> = values.iter().cloned().map(HashableValue).collect();
        assert!(HashableValueSlice(&values).equivalent(&owned));
    }

    #[test]
    fn slice_not_equivalent_on_length_mismatch() {
        // A probe that is a strict prefix of the stored key must not match.
        let values = [Value::from(1_i64)];
        let owned: Vec<HashableValue> = vec![
            HashableValue(Value::from(1_i64)),
            HashableValue(Value::from(2_i64)),
        ];
        assert!(!HashableValueSlice(&values).equivalent(&owned));
    }

    #[test]
    fn value_set_dedup() {
        let mut set = ValueSet::new();
        assert!(set.insert(Value::from(1_i64)));
        assert!(!set.insert(Value::from(1_i64)));
        assert!(set.insert(Value::from(2_i64)));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn value_set_float_bitwise_semantics() {
        // NaN dedups against itself (same bits); +0.0 and -0.0 stay distinct.
        let mut set = ValueSet::new();
        assert!(set.insert(Value::F64(f64::NAN)));
        assert!(!set.insert(Value::F64(f64::NAN)));
        assert!(set.insert(Value::F64(0.0)));
        assert!(set.insert(Value::F64(-0.0)));
        assert!(set.insert(Value::F32(f32::NAN)));
        assert!(!set.insert(Value::F32(f32::NAN)));
        assert_eq!(set.len(), 4);
    }

    #[test]
    fn value_set_distinguishes_variants_with_equal_payloads() {
        let mut set = ValueSet::new();
        assert!(set.is_empty());
        assert!(set.insert(Value::I32(0)));
        assert!(set.insert(Value::U32(0)));
        assert!(set.insert(Value::Null));
        assert!(!set.insert(Value::Null));
        assert_eq!(set.len(), 3);
    }
}