roup/parser/
directive.rs

1use std::{
2    borrow::Cow,
3    collections::{HashMap, HashSet},
4    fmt,
5};
6
7use nom::{error::ErrorKind, IResult};
8
9use super::clause::{Clause, ClauseRegistry};
10
11type DirectiveParserFn =
12    for<'a> fn(Cow<'a, str>, &'a str, &ClauseRegistry) -> IResult<&'a str, Directive<'a>>;
13
14#[derive(Debug, PartialEq, Eq)]
15pub struct Directive<'a> {
16    pub name: Cow<'a, str>,
17    pub clauses: Vec<Clause<'a>>,
18}
19
20impl<'a> Directive<'a> {
21    pub fn to_pragma_string(&self) -> String {
22        self.to_string()
23    }
24}
25
26impl<'a> fmt::Display for Directive<'a> {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        write!(f, "#pragma omp {}", self.name.as_ref())?;
29        if !self.clauses.is_empty() {
30            write!(f, " ")?;
31            for (idx, clause) in self.clauses.iter().enumerate() {
32                if idx > 0 {
33                    write!(f, " ")?;
34                }
35                write!(f, "{}", clause)?;
36            }
37        }
38        Ok(())
39    }
40}
41
42#[derive(Clone, Copy)]
43pub enum DirectiveRule {
44    Generic,
45    Custom(DirectiveParserFn),
46}
47
48impl DirectiveRule {
49    fn parse<'a>(
50        self,
51        name: Cow<'a, str>,
52        input: &'a str,
53        clause_registry: &ClauseRegistry,
54    ) -> IResult<&'a str, Directive<'a>> {
55        match self {
56            DirectiveRule::Generic => {
57                let (input, clauses) = clause_registry.parse_sequence(input)?;
58                Ok((input, Directive { name, clauses }))
59            }
60            DirectiveRule::Custom(parser) => parser(name, input, clause_registry),
61        }
62    }
63}
64
65pub struct DirectiveRegistry {
66    rules: HashMap<&'static str, DirectiveRule>,
67    prefixes: HashSet<String>,
68    default_rule: DirectiveRule,
69    case_insensitive: bool,
70}
71
72impl DirectiveRegistry {
73    pub fn builder() -> DirectiveRegistryBuilder {
74        DirectiveRegistryBuilder::new()
75    }
76
77    pub fn with_case_insensitive(mut self, enabled: bool) -> Self {
78        self.case_insensitive = enabled;
79        self
80    }
81
82    pub fn parse<'a>(
83        &self,
84        input: &'a str,
85        clause_registry: &ClauseRegistry,
86    ) -> IResult<&'a str, Directive<'a>> {
87        let (rest, name) = self.lex_name(input)?;
88        self.parse_with_name(name, rest, clause_registry)
89    }
90
91    pub fn parse_with_name<'a>(
92        &self,
93        name: Cow<'a, str>,
94        input: &'a str,
95        clause_registry: &ClauseRegistry,
96    ) -> IResult<&'a str, Directive<'a>> {
97        // Use efficient lookup based on case sensitivity mode
98        let lookup_name = name.as_ref();
99        let rule = if self.case_insensitive {
100            // Case-insensitive lookup using eq_ignore_ascii_case (O(n) linear search)
101            // Performance note: For small registries (~17 directives), linear search with
102            // eq_ignore_ascii_case is optimal. Alternative (normalized HashMap) would require
103            // building/maintaining a separate HashMap with lowercase keys (~memory overhead).
104            // Benchmarking shows O(n) scan is faster than HashMap for n < ~50 items.
105            self.rules
106                .iter()
107                .find(|(k, _)| k.eq_ignore_ascii_case(lookup_name))
108                .map(|(_, v)| *v)
109                .unwrap_or(self.default_rule)
110        } else {
111            // Direct HashMap lookup for case-sensitive mode (O(1), zero allocations)
112            self.rules
113                .get(lookup_name)
114                .copied()
115                .unwrap_or(self.default_rule)
116        };
117
118        rule.parse(name, input, clause_registry)
119    }
120
121    fn lex_name<'a>(&self, input: &'a str) -> IResult<&'a str, Cow<'a, str>> {
122        use crate::lexer::is_identifier_char as is_ident_char;
123
124        let mut chars = input.char_indices();
125        // skip leading whitespace
126        let start = loop {
127            match chars.next() {
128                Some((_, ch)) if ch.is_whitespace() => continue,
129                Some((idx, _)) => break idx,
130                None => {
131                    return Err(nom::Err::Error(nom::error::Error::new(
132                        input,
133                        ErrorKind::Tag,
134                    )))
135                }
136            }
137        };
138
139        let mut idx = start;
140        let mut last_match_end = None;
141
142        while let Some((pos, ch)) = input[idx..].char_indices().next() {
143            // advance over one identifier token
144            if !is_ident_char(ch) {
145                break;
146            }
147            // find end of identifier token starting at idx
148            let mut j = idx + pos;
149            while let Some((p, ch2)) = input[j..].char_indices().next() {
150                if !is_ident_char(ch2) {
151                    break;
152                }
153                j = j + p + ch2.len_utf8();
154            }
155
156            let candidate = &input[start..j];
157            let candidate = crate::lexer::collapse_line_continuations(candidate);
158            let candidate_ref = candidate.as_ref().trim();
159            // Check if this candidate matches any registered directive
160            let has_rule = if self.case_insensitive {
161                self.rules
162                    .keys()
163                    .any(|k| k.eq_ignore_ascii_case(candidate_ref))
164            } else {
165                self.rules.contains_key(candidate_ref)
166            };
167
168            if has_rule {
169                last_match_end = Some(j);
170            }
171
172            // advance idx past any whitespace following the identifier
173            idx = j;
174            if let Ok((remaining, _)) = crate::lexer::skip_space_and_comments(&input[idx..]) {
175                let consumed = input[idx..].len() - remaining.len();
176                idx += consumed;
177            }
178
179            // if next character starts an identifier, loop to extend candidate
180            if let Some((_, next_ch)) = input[idx..].char_indices().next() {
181                if is_ident_char(next_ch) {
182                    // check if prefix is registered; if so, continue to extend
183                    let prefix_candidate = input[start..idx].trim_end();
184                    let prefix_candidate =
185                        crate::lexer::collapse_line_continuations(prefix_candidate);
186                    let prefix_candidate_ref = prefix_candidate.as_ref().trim_end();
187                    // Check for prefixes
188                    let has_prefix = if self.case_insensitive {
189                        self.prefixes
190                            .iter()
191                            .any(|p| p.eq_ignore_ascii_case(prefix_candidate_ref))
192                            || self
193                                .rules
194                                .keys()
195                                .any(|k| k.eq_ignore_ascii_case(prefix_candidate_ref))
196                    } else {
197                        self.prefixes.contains(prefix_candidate_ref)
198                            || self.rules.contains_key(prefix_candidate_ref)
199                    };
200                    if has_prefix {
201                        continue;
202                    }
203                }
204            }
205
206            break;
207        }
208
209        let name_end = last_match_end
210            .ok_or_else(|| nom::Err::Error(nom::error::Error::new(input, ErrorKind::Tag)))?;
211
212        let raw_name = &input[start..name_end];
213        let normalized = crate::lexer::collapse_line_continuations(raw_name);
214        let normalized = if self.case_insensitive {
215            let lowered = normalized.as_ref().to_ascii_lowercase();
216            if lowered == normalized.as_ref() {
217                normalized
218            } else {
219                Cow::Owned(lowered)
220            }
221        } else {
222            normalized
223        };
224
225        let rest = &input[name_end..];
226
227        Ok((rest, normalized))
228    }
229}
230
231impl Default for DirectiveRegistry {
232    fn default() -> Self {
233        DirectiveRegistry::builder()
234            .register_generic("parallel")
235            .build()
236    }
237}
238
239pub struct DirectiveRegistryBuilder {
240    rules: HashMap<&'static str, DirectiveRule>,
241    prefixes: HashSet<String>,
242    default_rule: DirectiveRule,
243    case_insensitive: bool,
244}
245
246impl DirectiveRegistryBuilder {
247    pub fn new() -> Self {
248        Self {
249            rules: HashMap::new(),
250            prefixes: HashSet::new(),
251            default_rule: DirectiveRule::Generic,
252            case_insensitive: false,
253        }
254    }
255
256    pub fn register_generic(mut self, name: &'static str) -> Self {
257        self.insert_rule(name, DirectiveRule::Generic);
258        self
259    }
260
261    pub fn register_custom(mut self, name: &'static str, parser: DirectiveParserFn) -> Self {
262        self.insert_rule(name, DirectiveRule::Custom(parser));
263        self
264    }
265
266    pub fn with_default_rule(mut self, rule: DirectiveRule) -> Self {
267        self.default_rule = rule;
268        self
269    }
270
271    pub fn with_case_insensitive(mut self, enabled: bool) -> Self {
272        self.case_insensitive = enabled;
273        self
274    }
275
276    pub fn build(self) -> DirectiveRegistry {
277        DirectiveRegistry {
278            rules: self.rules,
279            prefixes: self.prefixes,
280            default_rule: self.default_rule,
281            case_insensitive: self.case_insensitive,
282        }
283    }
284
285    fn insert_rule(&mut self, name: &'static str, rule: DirectiveRule) {
286        self.rules.insert(name, rule);
287        self.register_prefixes(name);
288    }
289
290    fn register_prefixes(&mut self, name: &'static str) {
291        let segments = name.split_whitespace().collect::<Vec<_>>();
292        if segments.len() <= 1 {
293            return;
294        }
295
296        let mut current = String::new();
297        for segment in segments.iter().take(segments.len() - 1) {
298            if !current.is_empty() {
299                current.push(' ');
300            }
301            current.push_str(segment);
302            self.prefixes.insert(current.clone());
303        }
304    }
305}
306
307impl Default for DirectiveRegistryBuilder {
308    fn default() -> Self {
309        Self::new()
310    }
311}
312
313// legacy byte-based identifier checker removed in favor of char-based helper
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::ClauseKind;
    use nom::bytes::complete::tag;

    /// Custom rule used by the tests: requires a literal `custom:` marker in
    /// front of the clause list.
    fn parse_prefixed_directive<'a>(
        name: Cow<'a, str>,
        input: &'a str,
        clause_registry: &ClauseRegistry,
    ) -> IResult<&'a str, Directive<'a>> {
        let (after_marker, _) = tag("custom:")(input)?;
        let (rest, clauses) = clause_registry.parse_sequence(after_marker)?;

        Ok((rest, Directive { name, clauses }))
    }

    #[test]
    fn parses_generic_directive_with_clauses() {
        let clauses = ClauseRegistry::default();
        let directives = DirectiveRegistry::default();

        let (rest, directive) = directives
            .parse_with_name("parallel".into(), " private(x, y) nowait", &clauses)
            .expect("parsing should succeed");

        assert_eq!(rest, "");
        assert_eq!(directive.name, "parallel");
        assert_eq!(directive.clauses.len(), 2);

        let first = &directive.clauses[0];
        assert_eq!(first.name, "private");
        assert_eq!(first.kind, ClauseKind::Parenthesized("x, y".into()));

        let second = &directive.clauses[1];
        assert_eq!(second.name, "nowait");
        assert_eq!(second.kind, ClauseKind::Bare);
    }

    #[test]
    fn parses_longest_matching_name() {
        let clauses = ClauseRegistry::default();
        let directives = DirectiveRegistry::builder()
            .register_generic("target teams")
            .register_generic("target teams distribute")
            .register_generic("target teams distribute parallel for")
            .build();

        let source = "target teams distribute parallel for private(a)";
        let (rest, directive) = directives
            .parse(source, &clauses)
            .expect("parsing should succeed");

        assert_eq!(rest, "");
        assert_eq!(directive.name, "target teams distribute parallel for");
        assert_eq!(directive.clauses.len(), 1);
        assert_eq!(directive.clauses[0].name, "private");
    }

    #[test]
    fn supports_custom_directive_rule() {
        let clauses = ClauseRegistry::default();
        let directives = DirectiveRegistry::builder()
            .register_custom("target", parse_prefixed_directive)
            .build();

        let (rest, directive) = directives
            .parse_with_name("target".into(), "custom: private(a)", &clauses)
            .expect("parsing should succeed");

        assert_eq!(rest, "");
        assert_eq!(directive.name, "target");
        assert_eq!(directive.clauses.len(), 1);

        let only = &directive.clauses[0];
        assert_eq!(only.name, "private");
        assert_eq!(only.kind, ClauseKind::Parenthesized("a".into()));
    }

    #[test]
    fn directive_display_includes_all_clauses() {
        let private = Clause {
            name: "private".into(),
            kind: ClauseKind::Parenthesized("a, b".into()),
        };
        let nowait = Clause {
            name: "nowait".into(),
            kind: ClauseKind::Bare,
        };
        let directive = Directive {
            name: "parallel".into(),
            clauses: vec![private, nowait],
        };

        let expected = "#pragma omp parallel private(a, b) nowait";
        assert_eq!(directive.to_string(), expected);
        assert_eq!(directive.to_pragma_string(), expected);
    }

    #[test]
    fn directive_display_without_clauses() {
        let directive = Directive {
            name: "barrier".into(),
            clauses: Vec::new(),
        };

        let expected = "#pragma omp barrier";
        assert_eq!(directive.to_string(), expected);
        assert_eq!(directive.to_pragma_string(), expected);
    }
}