1use core::borrow::Borrow;
2use core::cmp::Ordering;
3use core::ops::{Bound, RangeBounds};
4
5use SearchBound::*;
6use SearchResult::*;
7
8use super::node::ForceResult::*;
9use super::node::{Handle, NodeRef, marker};
10
11pub(super) enum SearchBound<T> {
12    Included(T),
14    Excluded(T),
16    AllIncluded,
18    AllExcluded,
20}
21
22impl<T> SearchBound<T> {
23    pub(super) fn from_range(range_bound: Bound<T>) -> Self {
24        match range_bound {
25            Bound::Included(t) => Included(t),
26            Bound::Excluded(t) => Excluded(t),
27            Bound::Unbounded => AllIncluded,
28        }
29    }
30}
31
32pub(super) enum SearchResult<BorrowType, K, V, FoundType, GoDownType> {
33    Found(Handle<NodeRef<BorrowType, K, V, FoundType>, marker::KV>),
34    GoDown(Handle<NodeRef<BorrowType, K, V, GoDownType>, marker::Edge>),
35}
36
37pub(super) enum IndexResult {
38    KV(usize),
39    Edge(usize),
40}
41
42impl<BorrowType: marker::BorrowType, K, V> NodeRef<BorrowType, K, V, marker::LeafOrInternal> {
43    pub(super) fn search_tree<Q: ?Sized>(
50        mut self,
51        key: &Q,
52    ) -> SearchResult<BorrowType, K, V, marker::LeafOrInternal, marker::Leaf>
53    where
54        Q: Ord,
55        K: Borrow<Q>,
56    {
57        loop {
58            self = match self.search_node(key) {
59                Found(handle) => return Found(handle),
60                GoDown(handle) => match handle.force() {
61                    Leaf(leaf) => return GoDown(leaf),
62                    Internal(internal) => internal.descend(),
63                },
64            }
65        }
66    }
67
68    pub(super) fn search_tree_for_bifurcation<'r, Q: ?Sized, R>(
84        mut self,
85        range: &'r R,
86    ) -> Result<
87        (
88            NodeRef<BorrowType, K, V, marker::LeafOrInternal>,
89            usize,
90            usize,
91            SearchBound<&'r Q>,
92            SearchBound<&'r Q>,
93        ),
94        Handle<NodeRef<BorrowType, K, V, marker::Leaf>, marker::Edge>,
95    >
96    where
97        Q: Ord,
98        K: Borrow<Q>,
99        R: RangeBounds<Q>,
100    {
101        let is_set = <V as super::set_val::IsSetVal>::is_set_val();
103
104        let (start, end) = (range.start_bound(), range.end_bound());
107        match (start, end) {
108            (Bound::Excluded(s), Bound::Excluded(e)) if s == e => {
109                if is_set {
110                    panic!("range start and end are equal and excluded in BTreeSet")
111                } else {
112                    panic!("range start and end are equal and excluded in BTreeMap")
113                }
114            }
115            (Bound::Included(s) | Bound::Excluded(s), Bound::Included(e) | Bound::Excluded(e))
116                if s > e =>
117            {
118                if is_set {
119                    panic!("range start is greater than range end in BTreeSet")
120                } else {
121                    panic!("range start is greater than range end in BTreeMap")
122                }
123            }
124            _ => {}
125        }
126        let mut lower_bound = SearchBound::from_range(start);
127        let mut upper_bound = SearchBound::from_range(end);
128        loop {
129            let (lower_edge_idx, lower_child_bound) = self.find_lower_bound_index(lower_bound);
130            let (upper_edge_idx, upper_child_bound) =
131                unsafe { self.find_upper_bound_index(upper_bound, lower_edge_idx) };
132            if lower_edge_idx < upper_edge_idx {
133                return Ok((
134                    self,
135                    lower_edge_idx,
136                    upper_edge_idx,
137                    lower_child_bound,
138                    upper_child_bound,
139                ));
140            }
141            debug_assert_eq!(lower_edge_idx, upper_edge_idx);
142            let common_edge = unsafe { Handle::new_edge(self, lower_edge_idx) };
143            match common_edge.force() {
144                Leaf(common_edge) => return Err(common_edge),
145                Internal(common_edge) => {
146                    self = common_edge.descend();
147                    lower_bound = lower_child_bound;
148                    upper_bound = upper_child_bound;
149                }
150            }
151        }
152    }
153
154    pub(super) fn find_lower_bound_edge<'r, Q>(
160        self,
161        bound: SearchBound<&'r Q>,
162    ) -> (Handle<Self, marker::Edge>, SearchBound<&'r Q>)
163    where
164        Q: ?Sized + Ord,
165        K: Borrow<Q>,
166    {
167        let (edge_idx, bound) = self.find_lower_bound_index(bound);
168        let edge = unsafe { Handle::new_edge(self, edge_idx) };
169        (edge, bound)
170    }
171
172    pub(super) fn find_upper_bound_edge<'r, Q>(
174        self,
175        bound: SearchBound<&'r Q>,
176    ) -> (Handle<Self, marker::Edge>, SearchBound<&'r Q>)
177    where
178        Q: ?Sized + Ord,
179        K: Borrow<Q>,
180    {
181        let (edge_idx, bound) = unsafe { self.find_upper_bound_index(bound, 0) };
182        let edge = unsafe { Handle::new_edge(self, edge_idx) };
183        (edge, bound)
184    }
185}
186
187impl<BorrowType, K, V, Type> NodeRef<BorrowType, K, V, Type> {
188    pub(super) fn search_node<Q: ?Sized>(
196        self,
197        key: &Q,
198    ) -> SearchResult<BorrowType, K, V, Type, Type>
199    where
200        Q: Ord,
201        K: Borrow<Q>,
202    {
203        match unsafe { self.find_key_index(key, 0) } {
204            IndexResult::KV(idx) => Found(unsafe { Handle::new_kv(self, idx) }),
205            IndexResult::Edge(idx) => GoDown(unsafe { Handle::new_edge(self, idx) }),
206        }
207    }
208
209    unsafe fn find_key_index<Q: ?Sized>(&self, key: &Q, start_index: usize) -> IndexResult
218    where
219        Q: Ord,
220        K: Borrow<Q>,
221    {
222        let node = self.reborrow();
223        let keys = node.keys();
224        debug_assert!(start_index <= keys.len());
225        for (offset, k) in unsafe { keys.get_unchecked(start_index..) }.iter().enumerate() {
226            match key.cmp(k.borrow()) {
227                Ordering::Greater => {}
228                Ordering::Equal => return IndexResult::KV(start_index + offset),
229                Ordering::Less => return IndexResult::Edge(start_index + offset),
230            }
231        }
232        IndexResult::Edge(keys.len())
233    }
234
235    fn find_lower_bound_index<'r, Q>(
241        &self,
242        bound: SearchBound<&'r Q>,
243    ) -> (usize, SearchBound<&'r Q>)
244    where
245        Q: ?Sized + Ord,
246        K: Borrow<Q>,
247    {
248        match bound {
249            Included(key) => match unsafe { self.find_key_index(key, 0) } {
250                IndexResult::KV(idx) => (idx, AllExcluded),
251                IndexResult::Edge(idx) => (idx, bound),
252            },
253            Excluded(key) => match unsafe { self.find_key_index(key, 0) } {
254                IndexResult::KV(idx) => (idx + 1, AllIncluded),
255                IndexResult::Edge(idx) => (idx, bound),
256            },
257            AllIncluded => (0, AllIncluded),
258            AllExcluded => (self.len(), AllExcluded),
259        }
260    }
261
262    unsafe fn find_upper_bound_index<'r, Q>(
268        &self,
269        bound: SearchBound<&'r Q>,
270        start_index: usize,
271    ) -> (usize, SearchBound<&'r Q>)
272    where
273        Q: ?Sized + Ord,
274        K: Borrow<Q>,
275    {
276        match bound {
277            Included(key) => match unsafe { self.find_key_index(key, start_index) } {
278                IndexResult::KV(idx) => (idx + 1, AllExcluded),
279                IndexResult::Edge(idx) => (idx, bound),
280            },
281            Excluded(key) => match unsafe { self.find_key_index(key, start_index) } {
282                IndexResult::KV(idx) => (idx, AllIncluded),
283                IndexResult::Edge(idx) => (idx, bound),
284            },
285            AllIncluded => (self.len(), AllIncluded),
286            AllExcluded => (start_index, AllExcluded),
287        }
288    }
289}