1use std::cell::RefCell;
4use std::marker::PhantomData;
5use std::rc::Rc;
6
7use proc_macro2::Span;
8use quote::quote;
9use stageleft::runtime_support::{FreeVariableWithContextWithProps, QuoteTokens};
10
11use crate::compile::ir::{HydroNode, SharedNode};
12use crate::location::Location;
13
/// A borrowed handle to an IR node that can be named inside a quoted closure.
///
/// The handle is `Copy`; it stores only a reference to the cell that owns the
/// node. When the handle is turned into tokens, the node behind `ir_node` is
/// wrapped in `HydroNode::Singleton` (if it is not one already) and the capture
/// is recorded in the thread-local capture list.
///
/// `IS_MUT` selects the type the handle stands for inside the quoted code:
/// `&'a T` when `false` (the default), `&'a mut T` when `true`.
pub struct SingletonRef<'a, 'slf, T, L, const IS_MUT: bool = false> {
    /// Cell holding the IR node this handle refers to; borrowed for `'slf`.
    pub(crate) ir_node: &'slf RefCell<HydroNode>,
    /// Carries `'a`, `T` and `L` in the type without storing values of them.
    _phantom: PhantomData<(&'a T, L)>,
}
/// Shorthand for a [`SingletonRef`] whose `IS_MUT` flag is `true`.
pub type SingletonMut<'a, 'slf, T, L> = SingletonRef<'a, 'slf, T, L, true>;
28
29impl<'slf, T, L, const IS_MUT: bool> SingletonRef<'_, 'slf, T, L, IS_MUT> {
30 pub(crate) fn new(ir_node: &'slf RefCell<HydroNode>) -> Self {
32 Self {
33 ir_node,
34 _phantom: PhantomData,
35 }
36 }
37
38 pub fn as_ref(&self) -> SingletonRef<'_, 'slf, T, L, false> {
40 SingletonRef {
41 ir_node: self.ir_node,
42 _phantom: PhantomData,
43 }
44 }
45
46 pub fn as_mut(&self) -> SingletonRef<'_, 'slf, T, L, true> {
48 SingletonRef {
49 ir_node: self.ir_node,
50 _phantom: PhantomData,
51 }
52 }
53}
54
55impl<T, L, const IS_MUT: bool> Copy for SingletonRef<'_, '_, T, L, IS_MUT> {}
56impl<T, L, const IS_MUT: bool> Clone for SingletonRef<'_, '_, T, L, IS_MUT> {
57 fn clone(&self) -> Self {
58 *self
59 }
60}
61
thread_local! {
    /// Singleton captures recorded on this thread while a capture scope is open.
    ///
    /// `None` means no scope is active. `with_singleton_capture` installs
    /// `Some(vec![])` before running its closure and takes the vector back out
    /// afterwards. Each `SingletonRef` that is turned into tokens pushes one
    /// `(node, is_mut)` entry; the entry's position in the vector is the index
    /// embedded in the identifier produced by `singleton_ref_ident`.
    static SINGLETON_REFS: RefCell<Option<Vec<(HydroNode, bool)>>> = const { RefCell::new(None) };
}
68
69pub(crate) fn singleton_ref_ident(index: usize) -> syn::Ident {
71 syn::Ident::new(
72 &format!("__hydro_singleton_ref_{}", index),
73 Span::call_site(),
74 )
75}
76
77pub fn with_singleton_capture(
81 f: impl FnOnce() -> crate::compile::ir::DebugExpr,
82) -> crate::compile::ir::ClosureExpr {
83 SINGLETON_REFS.with(|cell| {
84 let prev = cell.borrow_mut().replace(Vec::new());
85 assert!(
86 prev.is_none(),
87 "nested singleton capture scopes are not supported"
88 );
89 });
90 let expr = (f)();
91 let singleton_refs = SINGLETON_REFS.with(|cell| cell.borrow_mut().take().unwrap());
92 crate::compile::ir::ClosureExpr::new(expr, singleton_refs)
93}
94
95impl<'a, 'slf, T: 'a, L, const IS_MUT: bool> SingletonRef<'a, 'slf, T, L, IS_MUT>
96where
97 L: Location<'a>,
98{
99 fn to_tokens_helper(self, _ctx: &L) -> (QuoteTokens, ()) {
100 let ident = SINGLETON_REFS.with(|cell| {
101 let mut guard = cell.borrow_mut();
102 let refs = guard.as_mut().expect(
103 "SingletonRef used inside q!() but no singleton capture scope is active. \
104 This is a bug — singleton capture should be set up by the operator that uses q!().",
105 );
106
107 let index = refs.len();
108 let ident = singleton_ref_ident(index);
109
110 let metadata = self.ir_node.borrow().metadata().clone();
111
112 if !matches!(&*self.ir_node.borrow(), HydroNode::Singleton { .. }) {
115 let orig = self.ir_node.replace(HydroNode::Placeholder);
116 *self.ir_node.borrow_mut() = HydroNode::Singleton {
117 inner: SharedNode(Rc::new(RefCell::new(orig))),
118 metadata: metadata.clone(),
119 };
120 }
121
122 let borrow: std::cell::Ref<'_, HydroNode> = self.ir_node.borrow();
123 let HydroNode::Singleton { inner, .. } = &*borrow else {
124 unreachable!()
125 };
126
127 refs.push((
128 HydroNode::Singleton {
129 inner: SharedNode(Rc::clone(&inner.0)),
130 metadata,
131 },
132 IS_MUT,
133 ));
134
135 ident
136 });
137
138 (
139 QuoteTokens {
140 prelude: None,
141 expr: Some(quote!(#ident)),
142 },
143 (),
144 )
145 }
146}
147
148impl<'a, 'slf, T: 'a, L> FreeVariableWithContextWithProps<L, ()> for SingletonRef<'a, 'slf, T, L>
149where
150 L: Location<'a>,
151{
152 type O = &'a T;
153
154 fn to_tokens(self, ctx: &L) -> (QuoteTokens, ()) {
155 self.to_tokens_helper(ctx)
156 }
157}
158
159impl<'a, 'slf, T: 'a, L> FreeVariableWithContextWithProps<L, ()> for SingletonMut<'a, 'slf, T, L>
160where
161 L: Location<'a>,
162{
163 type O = &'a mut T;
164
165 fn to_tokens(self, ctx: &L) -> (QuoteTokens, ()) {
166 self.to_tokens_helper(ctx)
167 }
168}
169
// Smoke tests for referencing a singleton from inside quoted closures.
//
// None of these tests contains an assertion: each one builds a flow that uses
// the value returned by `by_ref()` inside a `q!(...)` closure and then calls
// `flow.finalize()`. A test passes when that whole sequence completes without
// panicking.
// NOTE(review): `by_ref()` is defined outside this file; presumably it returns
// a `SingletonRef` — confirm.
#[cfg(test)]
#[cfg(feature = "build")]
mod tests {
    use stageleft::q;

    use crate::compile::builder::FlowBuilder;
    use crate::location::Location;

    // Marker type used as the process tag for `flow.process::<P1>()`.
    struct P1 {}

    /// References an `i32` fold result from a `map` closure.
    #[test]
    fn singleton_by_ref_compiles() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_ref = my_count.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + *count_ref))
            .for_each(q!(|_| {}));

        // The singleton itself is also consumed after being referenced.
        my_count.into_stream().for_each(q!(|_| {}));

        let _built = flow.finalize();
    }

    /// References a `Vec<i32>` fold result (a non-`Copy` value) from a `map`
    /// closure, calling `len()` through the reference.
    #[test]
    fn singleton_by_ref_non_copy() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let my_vec = node.source_iter(q!(0..5i32)).fold(
            q!(|| Vec::<i32>::new()),
            q!(|acc: &mut Vec<i32>, x| acc.push(x)),
        );
        let vec_ref = my_vec.by_ref();

        node.source_iter(q!(1..=3i32))
            .map(q!(|x| x + vec_ref.len() as i32))
            .for_each(q!(|_| {}));

        my_vec.into_stream().for_each(q!(|_| {}));

        let _built = flow.finalize();
    }

    /// References the singleton from a `filter` predicate.
    #[test]
    fn singleton_by_ref_filter() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        node.source_iter(q!(1..=10i32))
            .filter(q!(|x| *x > *threshold_ref))
            .for_each(q!(|_| {}));

        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// References the singleton from a `flat_map_ordered` closure; the inner
    /// `move` closure captures only `x`.
    #[test]
    fn singleton_by_ref_flat_map() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let count = node
            .source_iter(q!(0..3i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
        let count_ref = count.by_ref();

        node.source_iter(q!(1..=2i32))
            .flat_map_ordered(q!(|x| (0..*count_ref).map(move |i| x + i)))
            .for_each(q!(|_| {}));

        count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// References the singleton from an `inspect` closure.
    #[test]
    fn singleton_by_ref_inspect() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let count = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
        let count_ref = count.by_ref();

        node.source_iter(q!(1..=3i32))
            .inspect(q!(|x| println!("count={}, x={}", *count_ref, x)))
            .for_each(q!(|_| {}));

        count.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// References the singleton from a `partition` predicate and consumes both
    /// resulting streams directly.
    #[test]
    fn singleton_by_ref_partition() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        let (above, below) = node
            .source_iter(q!(1..=10i32))
            .partition(q!(|x| *x > *threshold_ref));

        above.for_each(q!(|_| {}));
        below.for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }

    /// Same as `singleton_by_ref_partition`, but each partitioned stream goes
    /// through a further `map` before being consumed.
    #[test]
    fn singleton_by_ref_partition_with_downstream_ops() {
        let mut flow = FlowBuilder::new();
        let node = flow.process::<P1>();

        let threshold = node
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        let (above, below) = node
            .source_iter(q!(1..=10i32))
            .partition(q!(|x| *x > *threshold_ref));

        above.map(q!(|x| x * 2)).for_each(q!(|_| {}));
        below.map(q!(|x| x + 100)).for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));
        let _built = flow.finalize();
    }
}