1use super::{
30 ClauseData, ClauseItem, DefaultKind, DependType, DirectiveIR, DirectiveKind, Expression,
31 Identifier, Language, MapType, ProcBind, ReductionOperator, ScheduleKind, ScheduleModifier,
32 SourceLocation,
33};
34
35pub struct DirectiveBuilder {
37 kind: DirectiveKind,
38 name: String,
39 clauses: Vec<ClauseData>,
40}
41
42impl<'a> DirectiveBuilder {
43 pub fn parallel() -> Self {
57 Self {
58 kind: DirectiveKind::Parallel,
59 name: "parallel".to_string(),
60 clauses: Vec::new(),
61 }
62 }
63
64 pub fn parallel_for() -> Self {
66 Self {
67 kind: DirectiveKind::ParallelFor,
68 name: "parallel for".to_string(),
69 clauses: Vec::new(),
70 }
71 }
72
73 pub fn for_loop() -> Self {
75 Self {
76 kind: DirectiveKind::For,
77 name: "for".to_string(),
78 clauses: Vec::new(),
79 }
80 }
81
82 pub fn task() -> Self {
84 Self {
85 kind: DirectiveKind::Task,
86 name: "task".to_string(),
87 clauses: Vec::new(),
88 }
89 }
90
91 pub fn target() -> Self {
93 Self {
94 kind: DirectiveKind::Target,
95 name: "target".to_string(),
96 clauses: Vec::new(),
97 }
98 }
99
100 pub fn teams() -> Self {
102 Self {
103 kind: DirectiveKind::Teams,
104 name: "teams".to_string(),
105 clauses: Vec::new(),
106 }
107 }
108
109 pub fn new(kind: DirectiveKind) -> Self {
111 let name = format!("{kind:?}").to_lowercase();
112 Self {
113 kind,
114 name,
115 clauses: Vec::new(),
116 }
117 }
118
119 pub fn default_shared(mut self) -> Self {
125 self.clauses.push(ClauseData::Default(DefaultKind::Shared));
126 self
127 }
128
129 pub fn default_none(mut self) -> Self {
131 self.clauses.push(ClauseData::Default(DefaultKind::None));
132 self
133 }
134
135 pub fn default(mut self, kind: DefaultKind) -> Self {
137 self.clauses.push(ClauseData::Default(kind));
138 self
139 }
140
141 pub fn private(mut self, vars: &[&'a str]) -> Self {
153 let items = vars
154 .iter()
155 .map(|&v| ClauseItem::Identifier(Identifier::new(v)))
156 .collect();
157 self.clauses.push(ClauseData::Private { items });
158 self
159 }
160
161 pub fn firstprivate(mut self, vars: &[&'a str]) -> Self {
163 let items = vars
164 .iter()
165 .map(|&v| ClauseItem::Identifier(Identifier::new(v)))
166 .collect();
167 self.clauses.push(ClauseData::Firstprivate { items });
168 self
169 }
170
171 pub fn shared(mut self, vars: &[&'a str]) -> Self {
173 let items = vars
174 .iter()
175 .map(|&v| ClauseItem::Identifier(Identifier::new(v)))
176 .collect();
177 self.clauses.push(ClauseData::Shared { items });
178 self
179 }
180
181 pub fn reduction(mut self, operator: ReductionOperator, vars: &[&'a str]) -> Self {
193 let items = vars
194 .iter()
195 .map(|&v| ClauseItem::Identifier(Identifier::new(v)))
196 .collect();
197 self.clauses.push(ClauseData::Reduction { operator, items });
198 self
199 }
200
201 pub fn num_threads(self, num: i32) -> Self {
206 self.num_threads_expr(Box::leak(Box::new(num.to_string())))
209 }
210
211 pub fn num_threads_expr(mut self, expr: &'a str) -> Self {
213 self.clauses.push(ClauseData::NumThreads {
214 num: Expression::unparsed(expr),
215 });
216 self
217 }
218
219 pub fn if_clause(mut self, condition: &'a str) -> Self {
221 self.clauses.push(ClauseData::If {
222 directive_name: None,
223 condition: Expression::unparsed(condition),
224 });
225 self
226 }
227
228 pub fn schedule_simple(mut self, kind: ScheduleKind) -> Self {
240 self.clauses.push(ClauseData::Schedule {
241 kind,
242 modifiers: vec![],
243 chunk_size: None,
244 });
245 self
246 }
247
248 pub fn schedule(mut self, kind: ScheduleKind, chunk_size: Option<&'a str>) -> Self {
250 self.clauses.push(ClauseData::Schedule {
251 kind,
252 modifiers: vec![],
253 chunk_size: chunk_size.map(Expression::unparsed),
254 });
255 self
256 }
257
258 pub fn schedule_with_modifiers(
260 mut self,
261 kind: ScheduleKind,
262 modifiers: Vec<ScheduleModifier>,
263 chunk_size: Option<&'a str>,
264 ) -> Self {
265 self.clauses.push(ClauseData::Schedule {
266 kind,
267 modifiers,
268 chunk_size: chunk_size.map(Expression::unparsed),
269 });
270 self
271 }
272
273 pub fn collapse(mut self, n: &'a str) -> Self {
275 self.clauses.push(ClauseData::Collapse {
276 n: Expression::unparsed(n),
277 });
278 self
279 }
280
281 pub fn ordered(mut self) -> Self {
283 self.clauses.push(ClauseData::Ordered { n: None });
284 self
285 }
286
287 pub fn ordered_n(mut self, n: &'a str) -> Self {
289 self.clauses.push(ClauseData::Ordered {
290 n: Some(Expression::unparsed(n)),
291 });
292 self
293 }
294
295 pub fn nowait(mut self) -> Self {
297 self.clauses
298 .push(ClauseData::Bare(Identifier::new("nowait")));
299 self
300 }
301
302 pub fn map(mut self, map_type: MapType, vars: &[&'a str]) -> Self {
304 let items = vars
305 .iter()
306 .map(|&v| ClauseItem::Identifier(Identifier::new(v)))
307 .collect();
308 self.clauses.push(ClauseData::Map {
309 map_type: Some(map_type),
310 mapper: None,
311 items,
312 });
313 self
314 }
315
316 pub fn depend(mut self, depend_type: DependType, vars: &[&'a str]) -> Self {
318 let items = vars
319 .iter()
320 .map(|&v| ClauseItem::Identifier(Identifier::new(v)))
321 .collect();
322 self.clauses.push(ClauseData::Depend { depend_type, items });
323 self
324 }
325
326 pub fn proc_bind(mut self, kind: ProcBind) -> Self {
328 self.clauses.push(ClauseData::ProcBind(kind));
329 self
330 }
331
332 pub fn build(self, location: SourceLocation, language: Language) -> DirectiveIR {
347 DirectiveIR::new(self.kind, &self.name, self.clauses, location, language)
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_parallel_simple() {
        let directive = DirectiveBuilder::parallel().build(SourceLocation::start(), Language::C);

        assert_eq!(directive.kind(), DirectiveKind::Parallel);
        assert_eq!(directive.clauses().len(), 0);
    }

    #[test]
    fn test_builder_parallel_with_default() {
        let directive = DirectiveBuilder::parallel()
            .default_shared()
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.clauses().len(), 1);
        assert!(directive.has_clause(|c| c.is_default()));
    }

    #[test]
    fn test_builder_parallel_with_multiple_clauses() {
        let directive = DirectiveBuilder::parallel()
            .default_shared()
            .private(&["x", "y"])
            .num_threads(4)
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.clauses().len(), 3);
        assert!(directive.has_clause(|c| c.is_default()));
        assert!(directive.has_clause(|c| c.is_private()));
        assert!(directive.has_clause(|c| c.is_num_threads()));
    }

    #[test]
    fn test_builder_parallel_for() {
        let directive = DirectiveBuilder::parallel_for()
            .schedule(ScheduleKind::Static, Some("16"))
            .reduction(ReductionOperator::Add, &["sum"])
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.kind(), DirectiveKind::ParallelFor);
        assert_eq!(directive.clauses().len(), 2);
        // The count alone would not catch a wrong clause kind being pushed.
        assert!(directive.has_clause(|c| c.is_schedule()));
    }

    #[test]
    fn test_builder_for_with_schedule() {
        let directive = DirectiveBuilder::for_loop()
            .schedule(ScheduleKind::Dynamic, Some("10"))
            .collapse("2")
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.kind(), DirectiveKind::For);
        assert_eq!(directive.clauses().len(), 2);
        assert!(directive.has_clause(|c| c.is_schedule()));
        assert!(directive.has_clause(|c| c.is_collapse()));
    }

    #[test]
    fn test_builder_target_with_map() {
        let directive = DirectiveBuilder::target()
            .map(MapType::To, &["arr"])
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.kind(), DirectiveKind::Target);
        assert_eq!(directive.clauses().len(), 1);
        assert!(directive.has_clause(|c| c.is_map()));
    }

    #[test]
    fn test_builder_task_with_depend() {
        let directive = DirectiveBuilder::task()
            .depend(DependType::In, &["x", "y"])
            .private(&["temp"])
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.kind(), DirectiveKind::Task);
        assert_eq!(directive.clauses().len(), 2);
        assert!(directive.has_clause(|c| c.is_depend()));
        assert!(directive.has_clause(|c| c.is_private()));
    }

    #[test]
    fn test_builder_method_chaining() {
        let directive = DirectiveBuilder::parallel()
            .default_shared()
            .private(&["i", "j"])
            .shared(&["data"])
            .reduction(ReductionOperator::Add, &["sum"])
            .num_threads(8)
            .if_clause("n > 100")
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.kind(), DirectiveKind::Parallel);
        assert_eq!(directive.clauses().len(), 6);
        assert!(directive.has_clause(|c| c.is_default()));
        assert!(directive.has_clause(|c| c.is_private()));
        assert!(directive.has_clause(|c| c.is_num_threads()));
    }

    #[test]
    fn test_builder_for_with_nowait() {
        let directive = DirectiveBuilder::for_loop()
            .schedule_simple(ScheduleKind::Static)
            .nowait()
            .build(SourceLocation::start(), Language::C);

        assert_eq!(directive.kind(), DirectiveKind::For);
        assert_eq!(directive.clauses().len(), 2);
        assert!(directive.has_clause(|c| c.is_schedule()));
    }

    #[test]
    fn test_builder_new_preserves_kind() {
        let task = DirectiveBuilder::new(DirectiveKind::Task)
            .build(SourceLocation::start(), Language::C);
        assert_eq!(task.kind(), DirectiveKind::Task);
        assert_eq!(task.clauses().len(), 0);

        let teams = DirectiveBuilder::new(DirectiveKind::Teams)
            .build(SourceLocation::start(), Language::C);
        assert_eq!(teams.kind(), DirectiveKind::Teams);

        let parallel_for = DirectiveBuilder::new(DirectiveKind::ParallelFor)
            .build(SourceLocation::start(), Language::C);
        assert_eq!(parallel_for.kind(), DirectiveKind::ParallelFor);
    }

    #[test]
    fn test_builder_display_roundtrip() {
        let directive = DirectiveBuilder::parallel()
            .default_shared()
            .private(&["x"])
            .build(SourceLocation::start(), Language::C);

        let output = directive.to_string();
        assert!(output.contains("parallel"));
        assert!(output.contains("default(shared)"));
        assert!(output.contains("private(x)"));
    }
}