1use super::{ClauseData, DirectiveIR, DirectiveKind};
30use std::fmt;
31
/// A violation of a clause placement or clause combination rule found while
/// validating a directive.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// A clause appeared on a directive that does not accept it.
    ClauseNotAllowed {
        /// Name of the offending clause (e.g. `nowait`).
        clause_name: String,
        /// Name of the directive the clause was attached to.
        directive: String,
        /// Human-readable statement of the placement rule that was broken.
        reason: String,
    },
    /// Two clauses that cannot appear together were both present.
    ConflictingClauses {
        /// First clause of the conflicting pair.
        clause1: String,
        /// Second clause of the conflicting pair.
        clause2: String,
        /// Why the two clauses conflict.
        reason: String,
    },
    /// A directive is missing a clause it needs.
    MissingRequiredClause {
        /// Name of the directive.
        directive: String,
        /// Name of the clause that must be present.
        required_clause: String,
    },
    /// A set of clauses is invalid as a group (e.g. a clause repeated too often).
    InvalidCombination {
        /// Names of every clause taking part in the invalid combination.
        clauses: Vec<String>,
        /// Why the combination is rejected.
        reason: String,
    },
}

impl fmt::Display for ValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ClauseNotAllowed {
                clause_name,
                directive,
                reason,
            } => write!(
                f,
                "Clause '{}' not allowed on '{}' directive: {}",
                clause_name, directive, reason
            ),
            Self::ConflictingClauses {
                clause1,
                clause2,
                reason,
            } => write!(
                f,
                "Conflicting clauses '{}' and '{}': {}",
                clause1, clause2, reason
            ),
            Self::MissingRequiredClause {
                directive,
                required_clause,
            } => write!(
                f,
                "Directive '{}' requires clause '{}'",
                directive, required_clause
            ),
            Self::InvalidCombination { clauses, reason } => {
                // Render the clause list as a comma-separated sequence inside brackets.
                let listed = clauses.join(", ");
                write!(f, "Invalid combination of clauses [{listed}]: {reason}")
            }
        }
    }
}

impl std::error::Error for ValidationError {}
104
105pub struct ValidationContext {
107 directive: DirectiveKind,
108}
109
110impl ValidationContext {
111 pub fn new(directive: DirectiveKind) -> Self {
113 Self { directive }
114 }
115
116 pub fn is_clause_allowed(&self, clause: &ClauseData) -> Result<(), ValidationError> {
118 let clause_name = self.clause_name(clause);
120
121 match clause {
122 ClauseData::Bare(name) if name.to_string() == "nowait" => {
124 if self.directive.is_worksharing() || self.directive == DirectiveKind::Target {
125 Ok(())
126 } else {
127 Err(ValidationError::ClauseNotAllowed {
128 clause_name,
129 directive: self.directive.to_string(),
130 reason: "nowait only allowed on worksharing constructs (for, sections, single) or target".to_string(),
131 })
132 }
133 }
134
135 ClauseData::Reduction { .. } => {
137 if self.directive.is_parallel()
138 || self.directive.is_worksharing()
139 || self.directive.is_simd()
140 || self.directive.is_teams()
141 {
142 Ok(())
143 } else {
144 Err(ValidationError::ClauseNotAllowed {
145 clause_name,
146 directive: self.directive.to_string(),
147 reason: "reduction requires parallel, worksharing, simd, or teams context"
148 .to_string(),
149 })
150 }
151 }
152
153 ClauseData::Schedule { .. } => {
155 if self.directive.is_loop() || self.directive.is_worksharing() {
156 Ok(())
157 } else {
158 Err(ValidationError::ClauseNotAllowed {
159 clause_name,
160 directive: self.directive.to_string(),
161 reason:
162 "schedule only allowed on loop constructs (for, parallel for, etc.)"
163 .to_string(),
164 })
165 }
166 }
167
168 ClauseData::NumThreads { .. } => {
170 if self.directive.is_parallel() {
171 Ok(())
172 } else {
173 Err(ValidationError::ClauseNotAllowed {
174 clause_name,
175 directive: self.directive.to_string(),
176 reason: "num_threads only allowed on parallel constructs".to_string(),
177 })
178 }
179 }
180
181 ClauseData::Map { .. } => {
183 if self.directive.is_target() {
184 Ok(())
185 } else {
186 Err(ValidationError::ClauseNotAllowed {
187 clause_name,
188 directive: self.directive.to_string(),
189 reason: "map only allowed on target constructs".to_string(),
190 })
191 }
192 }
193
194 ClauseData::Depend { .. } => {
196 if self.directive.is_task() || self.directive == DirectiveKind::Ordered {
197 Ok(())
198 } else {
199 Err(ValidationError::ClauseNotAllowed {
200 clause_name,
201 directive: self.directive.to_string(),
202 reason: "depend only allowed on task constructs or ordered".to_string(),
203 })
204 }
205 }
206
207 ClauseData::Linear { .. } => {
209 if self.directive.is_simd() || self.directive.is_loop() {
210 Ok(())
211 } else {
212 Err(ValidationError::ClauseNotAllowed {
213 clause_name,
214 directive: self.directive.to_string(),
215 reason: "linear only allowed on simd or loop constructs".to_string(),
216 })
217 }
218 }
219
220 ClauseData::Collapse { .. } => {
222 if self.directive.is_loop() || self.directive.is_worksharing() {
223 Ok(())
224 } else {
225 Err(ValidationError::ClauseNotAllowed {
226 clause_name,
227 directive: self.directive.to_string(),
228 reason: "collapse only allowed on loop constructs".to_string(),
229 })
230 }
231 }
232
233 ClauseData::Ordered { .. } => {
235 if self.directive.is_loop() || self.directive.is_worksharing() {
236 Ok(())
237 } else {
238 Err(ValidationError::ClauseNotAllowed {
239 clause_name,
240 directive: self.directive.to_string(),
241 reason: "ordered only allowed on loop constructs".to_string(),
242 })
243 }
244 }
245
246 ClauseData::ProcBind(_) => {
248 if self.directive.is_parallel() {
249 Ok(())
250 } else {
251 Err(ValidationError::ClauseNotAllowed {
252 clause_name,
253 directive: self.directive.to_string(),
254 reason: "proc_bind only allowed on parallel constructs".to_string(),
255 })
256 }
257 }
258
259 ClauseData::Private { .. }
261 | ClauseData::Firstprivate { .. }
262 | ClauseData::Lastprivate { .. }
263 | ClauseData::Shared { .. } => Ok(()),
264
265 ClauseData::Default(_) => {
267 if self.directive.is_parallel() || self.directive.is_task() {
268 Ok(())
269 } else {
270 Err(ValidationError::ClauseNotAllowed {
271 clause_name,
272 directive: self.directive.to_string(),
273 reason: "default only allowed on parallel or task constructs".to_string(),
274 })
275 }
276 }
277
278 ClauseData::If { .. } => Ok(()),
280
281 ClauseData::Generic { .. } => Ok(()),
283
284 _ => Ok(()),
286 }
287 }
288
289 fn clause_name(&self, clause: &ClauseData) -> String {
291 match clause {
292 ClauseData::Bare(name) => name.to_string(),
293 ClauseData::Private { .. } => "private".to_string(),
294 ClauseData::Firstprivate { .. } => "firstprivate".to_string(),
295 ClauseData::Lastprivate { .. } => "lastprivate".to_string(),
296 ClauseData::Shared { .. } => "shared".to_string(),
297 ClauseData::Default(_) => "default".to_string(),
298 ClauseData::Reduction { .. } => "reduction".to_string(),
299 ClauseData::Map { .. } => "map".to_string(),
300 ClauseData::Schedule { .. } => "schedule".to_string(),
301 ClauseData::Linear { .. } => "linear".to_string(),
302 ClauseData::If { .. } => "if".to_string(),
303 ClauseData::NumThreads { .. } => "num_threads".to_string(),
304 ClauseData::ProcBind(_) => "proc_bind".to_string(),
305 ClauseData::Collapse { .. } => "collapse".to_string(),
306 ClauseData::Ordered { .. } => "ordered".to_string(),
307 ClauseData::Depend { .. } => "depend".to_string(),
308 ClauseData::Generic { name, .. } => name.to_string(),
309 _ => "<unknown>".to_string(),
310 }
311 }
312
313 pub fn validate_all(&self, clauses: &[ClauseData]) -> Result<(), Vec<ValidationError>> {
315 let mut errors = Vec::new();
316
317 for clause in clauses {
319 if let Err(e) = self.is_clause_allowed(clause) {
320 errors.push(e);
321 }
322 }
323
324 if let Err(mut conflicts) = self.check_conflicts(clauses) {
326 errors.append(&mut conflicts);
327 }
328
329 if errors.is_empty() {
330 Ok(())
331 } else {
332 Err(errors)
333 }
334 }
335
336 fn check_conflicts(&self, clauses: &[ClauseData]) -> Result<(), Vec<ValidationError>> {
338 let mut errors = Vec::new();
339
340 let default_count = clauses.iter().filter(|c| c.is_default()).count();
342 if default_count > 1 {
343 errors.push(ValidationError::InvalidCombination {
344 clauses: vec!["default".to_string(); default_count],
345 reason: "only one default clause allowed".to_string(),
346 });
347 }
348
349 let num_threads_count = clauses.iter().filter(|c| c.is_num_threads()).count();
351 if num_threads_count > 1 {
352 errors.push(ValidationError::InvalidCombination {
353 clauses: vec!["num_threads".to_string(); num_threads_count],
354 reason: "only one num_threads clause allowed".to_string(),
355 });
356 }
357
358 let proc_bind_count = clauses.iter().filter(|c| c.is_proc_bind()).count();
360 if proc_bind_count > 1 {
361 errors.push(ValidationError::InvalidCombination {
362 clauses: vec!["proc_bind".to_string(); proc_bind_count],
363 reason: "only one proc_bind clause allowed".to_string(),
364 });
365 }
366
367 let has_ordered = clauses.iter().any(|c| c.is_ordered());
369 let has_auto_runtime = clauses.iter().any(|c| {
370 if let ClauseData::Schedule { kind, .. } = c {
371 matches!(
372 kind,
373 super::ScheduleKind::Auto | super::ScheduleKind::Runtime
374 )
375 } else {
376 false
377 }
378 });
379
380 if has_ordered && has_auto_runtime {
381 errors.push(ValidationError::ConflictingClauses {
382 clause1: "ordered".to_string(),
383 clause2: "schedule(auto/runtime)".to_string(),
384 reason: "ordered not compatible with schedule(auto) or schedule(runtime)"
385 .to_string(),
386 });
387 }
388
389 if errors.is_empty() {
390 Ok(())
391 } else {
392 Err(errors)
393 }
394 }
395}
396
397impl DirectiveIR {
398 pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
417 let context = ValidationContext::new(self.kind());
418 context.validate_all(self.clauses())
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{
        ClauseItem, DefaultKind, DependType, Identifier, Language, MapType, ReductionOperator,
        ScheduleKind, SourceLocation,
    };

    /// Runs the single-clause placement check for `clause` on a `kind` directive.
    fn check(kind: DirectiveKind, clause: &ClauseData) -> Result<(), ValidationError> {
        ValidationContext::new(kind).is_clause_allowed(clause)
    }

    /// A list item naming a plain variable, e.g. the `x` in `private(x)`.
    fn var(name: &'static str) -> ClauseItem {
        ClauseItem::Identifier(Identifier::new(name))
    }

    fn nowait() -> ClauseData {
        ClauseData::Bare(Identifier::new("nowait"))
    }

    fn sum_reduction() -> ClauseData {
        ClauseData::Reduction {
            operator: ReductionOperator::Add,
            items: vec![var("sum")],
        }
    }

    fn schedule(kind: ScheduleKind) -> ClauseData {
        ClauseData::Schedule {
            kind,
            modifiers: vec![],
            chunk_size: None,
        }
    }

    fn four_threads() -> ClauseData {
        ClauseData::NumThreads {
            num: crate::ir::Expression::unparsed("4"),
        }
    }

    fn map_to_arr() -> ClauseData {
        ClauseData::Map {
            map_type: Some(MapType::To),
            mapper: None,
            items: vec![var("arr")],
        }
    }

    #[test]
    fn test_nowait_allowed_on_for() {
        assert!(check(DirectiveKind::For, &nowait()).is_ok());
    }

    #[test]
    fn test_nowait_not_allowed_on_parallel() {
        assert!(check(DirectiveKind::Parallel, &nowait()).is_err());
    }

    #[test]
    fn test_reduction_allowed_on_parallel() {
        assert!(check(DirectiveKind::Parallel, &sum_reduction()).is_ok());
    }

    #[test]
    fn test_reduction_allowed_on_for() {
        assert!(check(DirectiveKind::For, &sum_reduction()).is_ok());
    }

    #[test]
    fn test_schedule_allowed_on_for() {
        assert!(check(DirectiveKind::For, &schedule(ScheduleKind::Static)).is_ok());
    }

    #[test]
    fn test_schedule_not_allowed_on_parallel() {
        assert!(check(DirectiveKind::Parallel, &schedule(ScheduleKind::Static)).is_err());
    }

    #[test]
    fn test_num_threads_allowed_on_parallel() {
        assert!(check(DirectiveKind::Parallel, &four_threads()).is_ok());
    }

    #[test]
    fn test_num_threads_not_allowed_on_for() {
        assert!(check(DirectiveKind::For, &four_threads()).is_err());
    }

    #[test]
    fn test_map_allowed_on_target() {
        assert!(check(DirectiveKind::Target, &map_to_arr()).is_ok());
    }

    #[test]
    fn test_map_not_allowed_on_parallel() {
        assert!(check(DirectiveKind::Parallel, &map_to_arr()).is_err());
    }

    #[test]
    fn test_depend_allowed_on_task() {
        let depend_in_x = ClauseData::Depend {
            depend_type: DependType::In,
            items: vec![var("x")],
        };
        assert!(check(DirectiveKind::Task, &depend_in_x).is_ok());
    }

    #[test]
    fn test_private_allowed_on_most_constructs() {
        let private_x = ClauseData::Private {
            items: vec![var("x")],
        };

        for kind in [DirectiveKind::Parallel, DirectiveKind::For, DirectiveKind::Task] {
            assert!(check(kind, &private_x).is_ok());
        }
    }

    #[test]
    fn test_multiple_default_clauses_conflict() {
        let clauses = [
            ClauseData::Default(DefaultKind::Shared),
            ClauseData::Default(DefaultKind::None),
        ];

        let errors = ValidationContext::new(DirectiveKind::Parallel)
            .validate_all(&clauses)
            .expect_err("two default clauses must be rejected");
        assert_eq!(errors.len(), 1);
        assert!(matches!(
            errors[0],
            ValidationError::InvalidCombination { .. }
        ));
    }

    #[test]
    fn test_ordered_schedule_auto_conflict() {
        let clauses = [ClauseData::Ordered { n: None }, schedule(ScheduleKind::Auto)];

        let errors = ValidationContext::new(DirectiveKind::For)
            .validate_all(&clauses)
            .expect_err("ordered with schedule(auto) must be rejected");
        let has_conflict = errors
            .iter()
            .any(|e| matches!(e, ValidationError::ConflictingClauses { .. }));
        assert!(has_conflict);
    }

    #[test]
    fn test_directive_ir_validate() {
        let ir = DirectiveIR::new(
            DirectiveKind::Parallel,
            "parallel",
            vec![ClauseData::Default(DefaultKind::Shared)],
            SourceLocation::start(),
            Language::C,
        );

        assert!(ir.validate().is_ok());
    }

    #[test]
    fn test_directive_ir_validate_invalid() {
        let ir = DirectiveIR::new(
            DirectiveKind::Parallel,
            "parallel",
            vec![nowait()],
            SourceLocation::start(),
            Language::C,
        );

        assert!(ir.validate().is_err());
    }
}