roup/ir/
validate.rs

1//! IR validation utilities
2//!
3//! This module provides validation for OpenMP directives and clauses,
4//! ensuring semantic correctness beyond just syntax.
5//!
6//! ## Learning Objectives
7//!
8//! - **Context-sensitive validation**: Rules depend on directive type
9//! - **Semantic checking**: Beyond syntax, check meaning and compatibility
10//! - **Error reporting**: Clear, actionable error messages
11//! - **Builder pattern**: Fluent API for constructing valid IR
12//!
13//! ## Validation Levels
14//!
15//! 1. **Syntax validation**: Already handled by parser
16//! 2. **Structural validation**: Clause exists, has required parts
17//! 3. **Semantic validation**: Clause allowed for this directive
18//! 4. **Consistency validation**: Clauses don't conflict with each other
19//!
20//! ## Example
21//!
22//! ```
23//! use roup::ir::{DirectiveKind, ClauseData, Identifier, ValidationContext};
24//!
25//! let context = ValidationContext::new(DirectiveKind::For);
26//! assert!(context.is_clause_allowed(&ClauseData::Bare(Identifier::new("nowait"))).is_ok());
27//! ```
28
29use super::{ClauseData, DirectiveIR, DirectiveKind};
30use std::fmt;
31
/// Errors reported when validating an OpenMP directive and its clauses.
///
/// Every variant carries plain strings so the error can outlive the IR it
/// was produced from and be rendered directly via `Display`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
    /// A clause appeared on a directive that does not accept it.
    ClauseNotAllowed {
        /// Display name of the offending clause.
        clause_name: String,
        /// Directive the clause was attached to.
        directive: String,
        /// Explanation of the placement rule that was violated.
        reason: String,
    },
    /// Two clauses that cannot be used together were both present.
    ConflictingClauses {
        /// First clause of the conflicting pair.
        clause1: String,
        /// Second clause of the conflicting pair.
        clause2: String,
        /// Why the two clauses conflict.
        reason: String,
    },
    /// A clause the directive requires is missing.
    MissingRequiredClause {
        /// Directive that needs the clause.
        directive: String,
        /// Name of the clause that must be present.
        required_clause: String,
    },
    /// A group of clauses is invalid taken together.
    InvalidCombination {
        /// Names of the clauses involved (one entry per occurrence).
        clauses: Vec<String>,
        /// Why the combination is rejected.
        reason: String,
    },
}
58
59impl fmt::Display for ValidationError {
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        match self {
62            ValidationError::ClauseNotAllowed {
63                clause_name,
64                directive,
65                reason,
66            } => {
67                write!(
68                    f,
69                    "Clause '{clause_name}' not allowed on '{directive}' directive: {reason}"
70                )
71            }
72            ValidationError::ConflictingClauses {
73                clause1,
74                clause2,
75                reason,
76            } => {
77                write!(
78                    f,
79                    "Conflicting clauses '{clause1}' and '{clause2}': {reason}"
80                )
81            }
82            ValidationError::MissingRequiredClause {
83                directive,
84                required_clause,
85            } => {
86                write!(
87                    f,
88                    "Directive '{directive}' requires clause '{required_clause}'"
89                )
90            }
91            ValidationError::InvalidCombination { clauses, reason } => {
92                write!(
93                    f,
94                    "Invalid combination of clauses [{}]: {}",
95                    clauses.join(", "),
96                    reason
97                )
98            }
99        }
100    }
101}
102
103impl std::error::Error for ValidationError {}
104
105/// Validation context for checking clause compatibility
106pub struct ValidationContext {
107    directive: DirectiveKind,
108}
109
110impl ValidationContext {
111    /// Create a new validation context for a directive
112    pub fn new(directive: DirectiveKind) -> Self {
113        Self { directive }
114    }
115
116    /// Check if a clause is allowed on this directive
117    pub fn is_clause_allowed(&self, clause: &ClauseData) -> Result<(), ValidationError> {
118        // Get clause name for error reporting
119        let clause_name = self.clause_name(clause);
120
121        match clause {
122            // nowait is only for worksharing, not parallel
123            ClauseData::Bare(name) if name.to_string() == "nowait" => {
124                if self.directive.is_worksharing() || self.directive == DirectiveKind::Target {
125                    Ok(())
126                } else {
127                    Err(ValidationError::ClauseNotAllowed {
128                        clause_name,
129                        directive: self.directive.to_string(),
130                        reason: "nowait only allowed on worksharing constructs (for, sections, single) or target".to_string(),
131                    })
132                }
133            }
134
135            // reduction requires parallel or worksharing or simd
136            ClauseData::Reduction { .. } => {
137                if self.directive.is_parallel()
138                    || self.directive.is_worksharing()
139                    || self.directive.is_simd()
140                    || self.directive.is_teams()
141                {
142                    Ok(())
143                } else {
144                    Err(ValidationError::ClauseNotAllowed {
145                        clause_name,
146                        directive: self.directive.to_string(),
147                        reason: "reduction requires parallel, worksharing, simd, or teams context"
148                            .to_string(),
149                    })
150                }
151            }
152
153            // schedule is only for loop constructs
154            ClauseData::Schedule { .. } => {
155                if self.directive.is_loop() || self.directive.is_worksharing() {
156                    Ok(())
157                } else {
158                    Err(ValidationError::ClauseNotAllowed {
159                        clause_name,
160                        directive: self.directive.to_string(),
161                        reason:
162                            "schedule only allowed on loop constructs (for, parallel for, etc.)"
163                                .to_string(),
164                    })
165                }
166            }
167
168            // num_threads is for parallel constructs
169            ClauseData::NumThreads { .. } => {
170                if self.directive.is_parallel() {
171                    Ok(())
172                } else {
173                    Err(ValidationError::ClauseNotAllowed {
174                        clause_name,
175                        directive: self.directive.to_string(),
176                        reason: "num_threads only allowed on parallel constructs".to_string(),
177                    })
178                }
179            }
180
181            // map is for target constructs
182            ClauseData::Map { .. } => {
183                if self.directive.is_target() {
184                    Ok(())
185                } else {
186                    Err(ValidationError::ClauseNotAllowed {
187                        clause_name,
188                        directive: self.directive.to_string(),
189                        reason: "map only allowed on target constructs".to_string(),
190                    })
191                }
192            }
193
194            // depend is for task constructs
195            ClauseData::Depend { .. } => {
196                if self.directive.is_task() || self.directive == DirectiveKind::Ordered {
197                    Ok(())
198                } else {
199                    Err(ValidationError::ClauseNotAllowed {
200                        clause_name,
201                        directive: self.directive.to_string(),
202                        reason: "depend only allowed on task constructs or ordered".to_string(),
203                    })
204                }
205            }
206
207            // linear is for simd/loop constructs
208            ClauseData::Linear { .. } => {
209                if self.directive.is_simd() || self.directive.is_loop() {
210                    Ok(())
211                } else {
212                    Err(ValidationError::ClauseNotAllowed {
213                        clause_name,
214                        directive: self.directive.to_string(),
215                        reason: "linear only allowed on simd or loop constructs".to_string(),
216                    })
217                }
218            }
219
220            // collapse is for loop constructs
221            ClauseData::Collapse { .. } => {
222                if self.directive.is_loop() || self.directive.is_worksharing() {
223                    Ok(())
224                } else {
225                    Err(ValidationError::ClauseNotAllowed {
226                        clause_name,
227                        directive: self.directive.to_string(),
228                        reason: "collapse only allowed on loop constructs".to_string(),
229                    })
230                }
231            }
232
233            // ordered is for loop constructs
234            ClauseData::Ordered { .. } => {
235                if self.directive.is_loop() || self.directive.is_worksharing() {
236                    Ok(())
237                } else {
238                    Err(ValidationError::ClauseNotAllowed {
239                        clause_name,
240                        directive: self.directive.to_string(),
241                        reason: "ordered only allowed on loop constructs".to_string(),
242                    })
243                }
244            }
245
246            // proc_bind is for parallel constructs
247            ClauseData::ProcBind(_) => {
248                if self.directive.is_parallel() {
249                    Ok(())
250                } else {
251                    Err(ValidationError::ClauseNotAllowed {
252                        clause_name,
253                        directive: self.directive.to_string(),
254                        reason: "proc_bind only allowed on parallel constructs".to_string(),
255                    })
256                }
257            }
258
259            // Data-sharing clauses (private, shared, etc.) allowed on most constructs
260            ClauseData::Private { .. }
261            | ClauseData::Firstprivate { .. }
262            | ClauseData::Lastprivate { .. }
263            | ClauseData::Shared { .. } => Ok(()),
264
265            // Default clause for parallel and task
266            ClauseData::Default(_) => {
267                if self.directive.is_parallel() || self.directive.is_task() {
268                    Ok(())
269                } else {
270                    Err(ValidationError::ClauseNotAllowed {
271                        clause_name,
272                        directive: self.directive.to_string(),
273                        reason: "default only allowed on parallel or task constructs".to_string(),
274                    })
275                }
276            }
277
278            // if clause allowed on most constructs
279            ClauseData::If { .. } => Ok(()),
280
281            // Generic clauses we don't validate yet
282            ClauseData::Generic { .. } => Ok(()),
283
284            // Other clauses default to allowed
285            _ => Ok(()),
286        }
287    }
288
289    /// Get a displayable name for a clause
290    fn clause_name(&self, clause: &ClauseData) -> String {
291        match clause {
292            ClauseData::Bare(name) => name.to_string(),
293            ClauseData::Private { .. } => "private".to_string(),
294            ClauseData::Firstprivate { .. } => "firstprivate".to_string(),
295            ClauseData::Lastprivate { .. } => "lastprivate".to_string(),
296            ClauseData::Shared { .. } => "shared".to_string(),
297            ClauseData::Default(_) => "default".to_string(),
298            ClauseData::Reduction { .. } => "reduction".to_string(),
299            ClauseData::Map { .. } => "map".to_string(),
300            ClauseData::Schedule { .. } => "schedule".to_string(),
301            ClauseData::Linear { .. } => "linear".to_string(),
302            ClauseData::If { .. } => "if".to_string(),
303            ClauseData::NumThreads { .. } => "num_threads".to_string(),
304            ClauseData::ProcBind(_) => "proc_bind".to_string(),
305            ClauseData::Collapse { .. } => "collapse".to_string(),
306            ClauseData::Ordered { .. } => "ordered".to_string(),
307            ClauseData::Depend { .. } => "depend".to_string(),
308            ClauseData::Generic { name, .. } => name.to_string(),
309            _ => "<unknown>".to_string(),
310        }
311    }
312
313    /// Validate all clauses in a directive
314    pub fn validate_all(&self, clauses: &[ClauseData]) -> Result<(), Vec<ValidationError>> {
315        let mut errors = Vec::new();
316
317        // Check each clause individually
318        for clause in clauses {
319            if let Err(e) = self.is_clause_allowed(clause) {
320                errors.push(e);
321            }
322        }
323
324        // Check for conflicting clauses
325        if let Err(mut conflicts) = self.check_conflicts(clauses) {
326            errors.append(&mut conflicts);
327        }
328
329        if errors.is_empty() {
330            Ok(())
331        } else {
332            Err(errors)
333        }
334    }
335
336    /// Check for conflicting clauses
337    fn check_conflicts(&self, clauses: &[ClauseData]) -> Result<(), Vec<ValidationError>> {
338        let mut errors = Vec::new();
339
340        // Check for multiple default clauses
341        let default_count = clauses.iter().filter(|c| c.is_default()).count();
342        if default_count > 1 {
343            errors.push(ValidationError::InvalidCombination {
344                clauses: vec!["default".to_string(); default_count],
345                reason: "only one default clause allowed".to_string(),
346            });
347        }
348
349        // Check for multiple num_threads clauses
350        let num_threads_count = clauses.iter().filter(|c| c.is_num_threads()).count();
351        if num_threads_count > 1 {
352            errors.push(ValidationError::InvalidCombination {
353                clauses: vec!["num_threads".to_string(); num_threads_count],
354                reason: "only one num_threads clause allowed".to_string(),
355            });
356        }
357
358        // Check for multiple proc_bind clauses
359        let proc_bind_count = clauses.iter().filter(|c| c.is_proc_bind()).count();
360        if proc_bind_count > 1 {
361            errors.push(ValidationError::InvalidCombination {
362                clauses: vec!["proc_bind".to_string(); proc_bind_count],
363                reason: "only one proc_bind clause allowed".to_string(),
364            });
365        }
366
367        // Check for ordered and schedule(auto/runtime) conflict
368        let has_ordered = clauses.iter().any(|c| c.is_ordered());
369        let has_auto_runtime = clauses.iter().any(|c| {
370            if let ClauseData::Schedule { kind, .. } = c {
371                matches!(
372                    kind,
373                    super::ScheduleKind::Auto | super::ScheduleKind::Runtime
374                )
375            } else {
376                false
377            }
378        });
379
380        if has_ordered && has_auto_runtime {
381            errors.push(ValidationError::ConflictingClauses {
382                clause1: "ordered".to_string(),
383                clause2: "schedule(auto/runtime)".to_string(),
384                reason: "ordered not compatible with schedule(auto) or schedule(runtime)"
385                    .to_string(),
386            });
387        }
388
389        if errors.is_empty() {
390            Ok(())
391        } else {
392            Err(errors)
393        }
394    }
395}
396
397impl DirectiveIR {
398    /// Validate this directive and its clauses
399    ///
400    /// ## Example
401    ///
402    /// ```
403    /// use roup::ir::{DirectiveIR, DirectiveKind, ClauseData, DefaultKind, Language, SourceLocation};
404    ///
405    /// let ir = DirectiveIR::new(
406    ///     DirectiveKind::Parallel,
407    ///     "parallel",
408    ///     vec![ClauseData::Default(DefaultKind::Shared)],
409    ///     SourceLocation::start(),
410    ///     Language::C,
411    /// );
412    ///
413    /// // This will validate successfully
414    /// assert!(ir.validate().is_ok());
415    /// ```
416    pub fn validate(&self) -> Result<(), Vec<ValidationError>> {
417        let context = ValidationContext::new(self.kind());
418        context.validate_all(self.clauses())
419    }
420}
421
422// ============================================================================
423// Tests
424// ============================================================================
425
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{
        ClauseItem, DefaultKind, DependType, Identifier, Language, MapType, ReductionOperator,
        ScheduleKind, SourceLocation,
    };

    // ---- fixtures ---------------------------------------------------------

    /// Whether `clause` passes placement validation on a `kind` directive.
    fn allowed_on(kind: DirectiveKind, clause: &ClauseData) -> bool {
        ValidationContext::new(kind).is_clause_allowed(clause).is_ok()
    }

    /// A clause list item naming the variable `name`.
    fn var(name: &str) -> ClauseItem {
        ClauseItem::Identifier(Identifier::new(name))
    }

    fn nowait() -> ClauseData {
        ClauseData::Bare(Identifier::new("nowait"))
    }

    fn reduction_sum() -> ClauseData {
        ClauseData::Reduction {
            operator: ReductionOperator::Add,
            items: vec![var("sum")],
        }
    }

    fn schedule(kind: ScheduleKind) -> ClauseData {
        ClauseData::Schedule {
            kind,
            modifiers: vec![],
            chunk_size: None,
        }
    }

    fn num_threads_4() -> ClauseData {
        ClauseData::NumThreads {
            num: crate::ir::Expression::unparsed("4"),
        }
    }

    fn map_to_arr() -> ClauseData {
        ClauseData::Map {
            map_type: Some(MapType::To),
            mapper: None,
            items: vec![var("arr")],
        }
    }

    /// A C `parallel` directive carrying `clauses`.
    fn parallel_with(clauses: Vec<ClauseData>) -> DirectiveIR {
        DirectiveIR::new(
            DirectiveKind::Parallel,
            "parallel",
            clauses,
            SourceLocation::start(),
            Language::C,
        )
    }

    // ---- per-clause placement rules ----------------------------------------

    #[test]
    fn test_nowait_allowed_on_for() {
        assert!(allowed_on(DirectiveKind::For, &nowait()));
    }

    #[test]
    fn test_nowait_not_allowed_on_parallel() {
        assert!(!allowed_on(DirectiveKind::Parallel, &nowait()));
    }

    #[test]
    fn test_reduction_allowed_on_parallel() {
        assert!(allowed_on(DirectiveKind::Parallel, &reduction_sum()));
    }

    #[test]
    fn test_reduction_allowed_on_for() {
        assert!(allowed_on(DirectiveKind::For, &reduction_sum()));
    }

    #[test]
    fn test_schedule_allowed_on_for() {
        assert!(allowed_on(DirectiveKind::For, &schedule(ScheduleKind::Static)));
    }

    #[test]
    fn test_schedule_not_allowed_on_parallel() {
        assert!(!allowed_on(DirectiveKind::Parallel, &schedule(ScheduleKind::Static)));
    }

    #[test]
    fn test_num_threads_allowed_on_parallel() {
        assert!(allowed_on(DirectiveKind::Parallel, &num_threads_4()));
    }

    #[test]
    fn test_num_threads_not_allowed_on_for() {
        assert!(!allowed_on(DirectiveKind::For, &num_threads_4()));
    }

    #[test]
    fn test_map_allowed_on_target() {
        assert!(allowed_on(DirectiveKind::Target, &map_to_arr()));
    }

    #[test]
    fn test_map_not_allowed_on_parallel() {
        assert!(!allowed_on(DirectiveKind::Parallel, &map_to_arr()));
    }

    #[test]
    fn test_depend_allowed_on_task() {
        let depend_in_x = ClauseData::Depend {
            depend_type: DependType::In,
            items: vec![var("x")],
        };
        assert!(allowed_on(DirectiveKind::Task, &depend_in_x));
    }

    #[test]
    fn test_private_allowed_on_most_constructs() {
        let private_x = ClauseData::Private {
            items: vec![var("x")],
        };
        for kind in [DirectiveKind::Parallel, DirectiveKind::For, DirectiveKind::Task] {
            assert!(allowed_on(kind, &private_x));
        }
    }

    // ---- cross-clause checks -----------------------------------------------

    #[test]
    fn test_multiple_default_clauses_conflict() {
        let clauses = [
            ClauseData::Default(DefaultKind::Shared),
            ClauseData::Default(DefaultKind::None),
        ];
        let errors = ValidationContext::new(DirectiveKind::Parallel)
            .validate_all(&clauses)
            .expect_err("two default clauses must be rejected");
        assert_eq!(errors.len(), 1);
        assert!(matches!(
            errors[0],
            ValidationError::InvalidCombination { .. }
        ));
    }

    #[test]
    fn test_ordered_schedule_auto_conflict() {
        let clauses = [ClauseData::Ordered { n: None }, schedule(ScheduleKind::Auto)];
        let errors = ValidationContext::new(DirectiveKind::For)
            .validate_all(&clauses)
            .expect_err("ordered with schedule(auto) must be rejected");
        assert!(errors
            .iter()
            .any(|e| matches!(e, ValidationError::ConflictingClauses { .. })));
    }

    // ---- DirectiveIR entry point -------------------------------------------

    #[test]
    fn test_directive_ir_validate() {
        let ir = parallel_with(vec![ClauseData::Default(DefaultKind::Shared)]);
        assert!(ir.validate().is_ok());
    }

    #[test]
    fn test_directive_ir_validate_invalid() {
        assert!(parallel_with(vec![nowait()]).validate().is_err());
    }
}