roup/ir/
builder.rs

1//! Builder pattern for constructing DirectiveIR
2//!
//! This module provides a fluent API for building OpenMP directives,
//! with convenience methods for the most common clauses. The builder does
//! not check that a clause is legal for the chosen directive.
5//!
6//! ## Learning Objectives
7//!
8//! - **Builder pattern**: Fluent API for object construction
//! - **By-value builder**: every method takes and returns `self`, so `build` consumes the builder
10//! - **Method chaining**: Ergonomic API design
11//! - **Smart defaults**: Sensible default values
12//!
13//! ## Example
14//!
15//! ```
16//! use roup::ir::{DirectiveBuilder, Language, SourceLocation};
17//!
18//! // Build a parallel directive with clauses
19//! let directive = DirectiveBuilder::parallel()
20//!     .default_shared()
21//!     .private(&["x", "y"])
22//!     .num_threads(4)
23//!     .build(SourceLocation::start(), Language::C);
24//!
25//! assert!(directive.kind().is_parallel());
26//! assert_eq!(directive.clauses().len(), 3);
27//! ```
28
29use super::{
30    ClauseData, ClauseItem, DefaultKind, DependType, DirectiveIR, DirectiveKind, Expression,
31    Identifier, Language, MapType, ProcBind, ReductionOperator, ScheduleKind, ScheduleModifier,
32    SourceLocation,
33};
34
35/// Builder for constructing DirectiveIR with a fluent API
36pub struct DirectiveBuilder {
37    kind: DirectiveKind,
38    name: String,
39    clauses: Vec<ClauseData>,
40}
41
42impl<'a> DirectiveBuilder {
43    /// Create a new builder for a parallel directive
44    ///
45    /// ## Example
46    ///
47    /// ```
48    /// use roup::ir::{DirectiveBuilder, Language, SourceLocation};
49    ///
50    /// let directive = DirectiveBuilder::parallel()
51    ///     .default_shared()
52    ///     .build(SourceLocation::start(), Language::C);
53    ///
54    /// assert!(directive.kind().is_parallel());
55    /// ```
56    pub fn parallel() -> Self {
57        Self {
58            kind: DirectiveKind::Parallel,
59            name: "parallel".to_string(),
60            clauses: Vec::new(),
61        }
62    }
63
64    /// Create a new builder for a parallel for directive
65    pub fn parallel_for() -> Self {
66        Self {
67            kind: DirectiveKind::ParallelFor,
68            name: "parallel for".to_string(),
69            clauses: Vec::new(),
70        }
71    }
72
73    /// Create a new builder for a for directive
74    pub fn for_loop() -> Self {
75        Self {
76            kind: DirectiveKind::For,
77            name: "for".to_string(),
78            clauses: Vec::new(),
79        }
80    }
81
82    /// Create a new builder for a task directive
83    pub fn task() -> Self {
84        Self {
85            kind: DirectiveKind::Task,
86            name: "task".to_string(),
87            clauses: Vec::new(),
88        }
89    }
90
91    /// Create a new builder for a target directive
92    pub fn target() -> Self {
93        Self {
94            kind: DirectiveKind::Target,
95            name: "target".to_string(),
96            clauses: Vec::new(),
97        }
98    }
99
100    /// Create a new builder for a teams directive
101    pub fn teams() -> Self {
102        Self {
103            kind: DirectiveKind::Teams,
104            name: "teams".to_string(),
105            clauses: Vec::new(),
106        }
107    }
108
109    /// Create a new builder for any directive kind
110    pub fn new(kind: DirectiveKind) -> Self {
111        let name = format!("{kind:?}").to_lowercase();
112        Self {
113            kind,
114            name,
115            clauses: Vec::new(),
116        }
117    }
118
119    // ========================================================================
120    // Clause builders
121    // ========================================================================
122
123    /// Add a default(shared) clause
124    pub fn default_shared(mut self) -> Self {
125        self.clauses.push(ClauseData::Default(DefaultKind::Shared));
126        self
127    }
128
129    /// Add a default(none) clause
130    pub fn default_none(mut self) -> Self {
131        self.clauses.push(ClauseData::Default(DefaultKind::None));
132        self
133    }
134
135    /// Add a default clause with specified kind
136    pub fn default(mut self, kind: DefaultKind) -> Self {
137        self.clauses.push(ClauseData::Default(kind));
138        self
139    }
140
141    /// Add a private clause
142    ///
143    /// ## Example
144    ///
145    /// ```
146    /// use roup::ir::{DirectiveBuilder, Language, SourceLocation};
147    ///
148    /// let directive = DirectiveBuilder::parallel()
149    ///     .private(&["x", "y", "z"])
150    ///     .build(SourceLocation::start(), Language::C);
151    /// ```
152    pub fn private(mut self, vars: &[&'a str]) -> Self {
153        let items = vars
154            .iter()
155            .map(|&v| ClauseItem::Identifier(Identifier::new(v)))
156            .collect();
157        self.clauses.push(ClauseData::Private { items });
158        self
159    }
160
161    /// Add a firstprivate clause
162    pub fn firstprivate(mut self, vars: &[&'a str]) -> Self {
163        let items = vars
164            .iter()
165            .map(|&v| ClauseItem::Identifier(Identifier::new(v)))
166            .collect();
167        self.clauses.push(ClauseData::Firstprivate { items });
168        self
169    }
170
171    /// Add a shared clause
172    pub fn shared(mut self, vars: &[&'a str]) -> Self {
173        let items = vars
174            .iter()
175            .map(|&v| ClauseItem::Identifier(Identifier::new(v)))
176            .collect();
177        self.clauses.push(ClauseData::Shared { items });
178        self
179    }
180
181    /// Add a reduction clause
182    ///
183    /// ## Example
184    ///
185    /// ```
186    /// use roup::ir::{DirectiveBuilder, ReductionOperator, Language, SourceLocation};
187    ///
188    /// let directive = DirectiveBuilder::parallel()
189    ///     .reduction(ReductionOperator::Add, &["sum"])
190    ///     .build(SourceLocation::start(), Language::C);
191    /// ```
192    pub fn reduction(mut self, operator: ReductionOperator, vars: &[&'a str]) -> Self {
193        let items = vars
194            .iter()
195            .map(|&v| ClauseItem::Identifier(Identifier::new(v)))
196            .collect();
197        self.clauses.push(ClauseData::Reduction { operator, items });
198        self
199    }
200
201    /// Add a num_threads clause
202    ///
203    /// Note: This creates an unparsed expression. For better control,
204    /// use `num_threads_expr()` with a static string.
205    pub fn num_threads(self, num: i32) -> Self {
206        // Note: For production use, consider requiring static strings
207        // or storing owned strings. This is a convenience method.
208        self.num_threads_expr(Box::leak(Box::new(num.to_string())))
209    }
210
211    /// Add a num_threads clause with expression
212    pub fn num_threads_expr(mut self, expr: &'a str) -> Self {
213        self.clauses.push(ClauseData::NumThreads {
214            num: Expression::unparsed(expr),
215        });
216        self
217    }
218
219    /// Add an if clause
220    pub fn if_clause(mut self, condition: &'a str) -> Self {
221        self.clauses.push(ClauseData::If {
222            directive_name: None,
223            condition: Expression::unparsed(condition),
224        });
225        self
226    }
227
228    /// Add a schedule clause
229    ///
230    /// ## Example
231    ///
232    /// ```
233    /// use roup::ir::{DirectiveBuilder, ScheduleKind, Language, SourceLocation};
234    ///
235    /// let directive = DirectiveBuilder::for_loop()
236    ///     .schedule_simple(ScheduleKind::Static)
237    ///     .build(SourceLocation::start(), Language::C);
238    /// ```
239    pub fn schedule_simple(mut self, kind: ScheduleKind) -> Self {
240        self.clauses.push(ClauseData::Schedule {
241            kind,
242            modifiers: vec![],
243            chunk_size: None,
244        });
245        self
246    }
247
248    /// Add a schedule clause with chunk size expression
249    pub fn schedule(mut self, kind: ScheduleKind, chunk_size: Option<&'a str>) -> Self {
250        self.clauses.push(ClauseData::Schedule {
251            kind,
252            modifiers: vec![],
253            chunk_size: chunk_size.map(Expression::unparsed),
254        });
255        self
256    }
257
258    /// Add a schedule clause with modifiers
259    pub fn schedule_with_modifiers(
260        mut self,
261        kind: ScheduleKind,
262        modifiers: Vec<ScheduleModifier>,
263        chunk_size: Option<&'a str>,
264    ) -> Self {
265        self.clauses.push(ClauseData::Schedule {
266            kind,
267            modifiers,
268            chunk_size: chunk_size.map(Expression::unparsed),
269        });
270        self
271    }
272
273    /// Add a collapse clause with expression
274    pub fn collapse(mut self, n: &'a str) -> Self {
275        self.clauses.push(ClauseData::Collapse {
276            n: Expression::unparsed(n),
277        });
278        self
279    }
280
281    /// Add an ordered clause
282    pub fn ordered(mut self) -> Self {
283        self.clauses.push(ClauseData::Ordered { n: None });
284        self
285    }
286
287    /// Add an ordered clause with parameter
288    pub fn ordered_n(mut self, n: &'a str) -> Self {
289        self.clauses.push(ClauseData::Ordered {
290            n: Some(Expression::unparsed(n)),
291        });
292        self
293    }
294
295    /// Add a nowait clause
296    pub fn nowait(mut self) -> Self {
297        self.clauses
298            .push(ClauseData::Bare(Identifier::new("nowait")));
299        self
300    }
301
302    /// Add a map clause
303    pub fn map(mut self, map_type: MapType, vars: &[&'a str]) -> Self {
304        let items = vars
305            .iter()
306            .map(|&v| ClauseItem::Identifier(Identifier::new(v)))
307            .collect();
308        self.clauses.push(ClauseData::Map {
309            map_type: Some(map_type),
310            mapper: None,
311            items,
312        });
313        self
314    }
315
316    /// Add a depend clause
317    pub fn depend(mut self, depend_type: DependType, vars: &[&'a str]) -> Self {
318        let items = vars
319            .iter()
320            .map(|&v| ClauseItem::Identifier(Identifier::new(v)))
321            .collect();
322        self.clauses.push(ClauseData::Depend { depend_type, items });
323        self
324    }
325
326    /// Add a proc_bind clause
327    pub fn proc_bind(mut self, kind: ProcBind) -> Self {
328        self.clauses.push(ClauseData::ProcBind(kind));
329        self
330    }
331
332    /// Build the DirectiveIR
333    ///
334    /// ## Example
335    ///
336    /// ```
337    /// use roup::ir::{DirectiveBuilder, Language, SourceLocation};
338    ///
339    /// let directive = DirectiveBuilder::parallel()
340    ///     .default_shared()
341    ///     .private(&["x"])
342    ///     .build(SourceLocation::start(), Language::C);
343    ///
344    /// assert_eq!(directive.clauses().len(), 2);
345    /// ```
346    pub fn build(self, location: SourceLocation, language: Language) -> DirectiveIR {
347        DirectiveIR::new(self.kind, &self.name, self.clauses, location, language)
348    }
349}
350
351// ============================================================================
352// Tests
353// ============================================================================
354
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_parallel_simple() {
        let directive = DirectiveBuilder::parallel().build(SourceLocation::start(), Language::C);

        assert_eq!(directive.kind(), DirectiveKind::Parallel);
        assert_eq!(directive.clauses().len(), 0);
    }

    #[test]
    fn test_builder_parallel_with_default() {
        let directive = DirectiveBuilder::parallel()
            .default_shared()
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.clauses().len(), 1);
        assert!(directive.has_clause(|c| c.is_default()));
    }

    #[test]
    fn test_builder_parallel_with_multiple_clauses() {
        let directive = DirectiveBuilder::parallel()
            .default_shared()
            .private(&["x", "y"])
            .num_threads(4)
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.clauses().len(), 3);
        assert!(directive.has_clause(|c| c.is_default()));
        assert!(directive.has_clause(|c| c.is_private()));
        assert!(directive.has_clause(|c| c.is_num_threads()));
    }

    #[test]
    fn test_builder_parallel_for() {
        let directive = DirectiveBuilder::parallel_for()
            .schedule(ScheduleKind::Static, Some("16"))
            .reduction(ReductionOperator::Add, &["sum"])
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.kind(), DirectiveKind::ParallelFor);
        assert_eq!(directive.clauses().len(), 2);
    }

    #[test]
    fn test_builder_for_with_schedule() {
        let directive = DirectiveBuilder::for_loop()
            .schedule(ScheduleKind::Dynamic, Some("10"))
            .collapse("2")
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.kind(), DirectiveKind::For);
        assert!(directive.has_clause(|c| c.is_schedule()));
        assert!(directive.has_clause(|c| c.is_collapse()));
    }

    #[test]
    fn test_builder_schedule_with_modifiers_no_modifiers() {
        let directive = DirectiveBuilder::for_loop()
            .schedule_with_modifiers(ScheduleKind::Dynamic, Vec::new(), Some("4"))
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.clauses().len(), 1);
        assert!(directive.has_clause(|c| c.is_schedule()));
    }

    #[test]
    fn test_builder_target_with_map() {
        let directive = DirectiveBuilder::target()
            .map(MapType::To, &["arr"])
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.kind(), DirectiveKind::Target);
        assert!(directive.has_clause(|c| c.is_map()));
    }

    #[test]
    fn test_builder_task_with_depend() {
        let directive = DirectiveBuilder::task()
            .depend(DependType::In, &["x", "y"])
            .private(&["temp"])
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.kind(), DirectiveKind::Task);
        assert!(directive.has_clause(|c| c.is_depend()));
        assert!(directive.has_clause(|c| c.is_private()));
    }

    #[test]
    fn test_builder_new_sets_kind() {
        let directive =
            DirectiveBuilder::new(DirectiveKind::Teams).build(SourceLocation::start(), Language::C);

        assert_eq!(directive.kind(), DirectiveKind::Teams);
        assert_eq!(directive.clauses().len(), 0);
    }

    #[test]
    fn test_builder_default_variants_firstprivate_and_ordered() {
        let directive = DirectiveBuilder::parallel_for()
            .default_none()
            .default(DefaultKind::Shared)
            .firstprivate(&["a", "b"])
            .ordered()
            .ordered_n("2")
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.clauses().len(), 5);
        assert!(directive.has_clause(|c| c.is_default()));
    }

    #[test]
    fn test_builder_method_chaining() {
        let directive = DirectiveBuilder::parallel()
            .default_shared()
            .private(&["i", "j"])
            .shared(&["data"])
            .reduction(ReductionOperator::Add, &["sum"])
            .num_threads(8)
            .if_clause("n > 100")
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.clauses().len(), 6);
    }

    #[test]
    fn test_builder_for_with_nowait() {
        let directive = DirectiveBuilder::for_loop()
            .schedule_simple(ScheduleKind::Static)
            .nowait()
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.clauses().len(), 2);
    }

    #[test]
    fn test_builder_display_roundtrip() {
        let directive = DirectiveBuilder::parallel()
            .default_shared()
            .private(&["x"])
            .build(SourceLocation::start(), Language::C);

        let output = directive.to_string();
        assert!(output.contains("parallel"));
        assert!(output.contains("default(shared)"));
        assert!(output.contains("private(x)"));
    }

    #[test]
    fn test_builder_num_threads_renders_literal() {
        let directive = DirectiveBuilder::parallel()
            .num_threads(4)
            .build(SourceLocation::start(), Language::C);

        assert!(directive.to_string().contains("num_threads(4)"));
    }
}