1use std::{
2 borrow::Cow,
3 collections::{HashMap, HashSet},
4 fmt,
5};
6
7use nom::{error::ErrorKind, IResult};
8
9use super::clause::{Clause, ClauseRegistry};
10
11type DirectiveParserFn =
12 for<'a> fn(Cow<'a, str>, &'a str, &ClauseRegistry) -> IResult<&'a str, Directive<'a>>;
13
14#[derive(Debug, PartialEq, Eq)]
15pub struct Directive<'a> {
16 pub name: Cow<'a, str>,
17 pub clauses: Vec<Clause<'a>>,
18}
19
20impl<'a> Directive<'a> {
21 pub fn to_pragma_string(&self) -> String {
22 self.to_string()
23 }
24}
25
26impl<'a> fmt::Display for Directive<'a> {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 write!(f, "#pragma omp {}", self.name.as_ref())?;
29 if !self.clauses.is_empty() {
30 write!(f, " ")?;
31 for (idx, clause) in self.clauses.iter().enumerate() {
32 if idx > 0 {
33 write!(f, " ")?;
34 }
35 write!(f, "{}", clause)?;
36 }
37 }
38 Ok(())
39 }
40}
41
42#[derive(Clone, Copy)]
43pub enum DirectiveRule {
44 Generic,
45 Custom(DirectiveParserFn),
46}
47
48impl DirectiveRule {
49 fn parse<'a>(
50 self,
51 name: Cow<'a, str>,
52 input: &'a str,
53 clause_registry: &ClauseRegistry,
54 ) -> IResult<&'a str, Directive<'a>> {
55 match self {
56 DirectiveRule::Generic => {
57 let (input, clauses) = clause_registry.parse_sequence(input)?;
58 Ok((input, Directive { name, clauses }))
59 }
60 DirectiveRule::Custom(parser) => parser(name, input, clause_registry),
61 }
62 }
63}
64
65pub struct DirectiveRegistry {
66 rules: HashMap<&'static str, DirectiveRule>,
67 prefixes: HashSet<String>,
68 default_rule: DirectiveRule,
69 case_insensitive: bool,
70}
71
72impl DirectiveRegistry {
73 pub fn builder() -> DirectiveRegistryBuilder {
74 DirectiveRegistryBuilder::new()
75 }
76
77 pub fn with_case_insensitive(mut self, enabled: bool) -> Self {
78 self.case_insensitive = enabled;
79 self
80 }
81
82 pub fn parse<'a>(
83 &self,
84 input: &'a str,
85 clause_registry: &ClauseRegistry,
86 ) -> IResult<&'a str, Directive<'a>> {
87 let (rest, name) = self.lex_name(input)?;
88 self.parse_with_name(name, rest, clause_registry)
89 }
90
91 pub fn parse_with_name<'a>(
92 &self,
93 name: Cow<'a, str>,
94 input: &'a str,
95 clause_registry: &ClauseRegistry,
96 ) -> IResult<&'a str, Directive<'a>> {
97 let lookup_name = name.as_ref();
99 let rule = if self.case_insensitive {
100 self.rules
106 .iter()
107 .find(|(k, _)| k.eq_ignore_ascii_case(lookup_name))
108 .map(|(_, v)| *v)
109 .unwrap_or(self.default_rule)
110 } else {
111 self.rules
113 .get(lookup_name)
114 .copied()
115 .unwrap_or(self.default_rule)
116 };
117
118 rule.parse(name, input, clause_registry)
119 }
120
121 fn lex_name<'a>(&self, input: &'a str) -> IResult<&'a str, Cow<'a, str>> {
122 use crate::lexer::is_identifier_char as is_ident_char;
123
124 let mut chars = input.char_indices();
125 let start = loop {
127 match chars.next() {
128 Some((_, ch)) if ch.is_whitespace() => continue,
129 Some((idx, _)) => break idx,
130 None => {
131 return Err(nom::Err::Error(nom::error::Error::new(
132 input,
133 ErrorKind::Tag,
134 )))
135 }
136 }
137 };
138
139 let mut idx = start;
140 let mut last_match_end = None;
141
142 while let Some((pos, ch)) = input[idx..].char_indices().next() {
143 if !is_ident_char(ch) {
145 break;
146 }
147 let mut j = idx + pos;
149 while let Some((p, ch2)) = input[j..].char_indices().next() {
150 if !is_ident_char(ch2) {
151 break;
152 }
153 j = j + p + ch2.len_utf8();
154 }
155
156 let candidate = &input[start..j];
157 let candidate = crate::lexer::collapse_line_continuations(candidate);
158 let candidate_ref = candidate.as_ref().trim();
159 let has_rule = if self.case_insensitive {
161 self.rules
162 .keys()
163 .any(|k| k.eq_ignore_ascii_case(candidate_ref))
164 } else {
165 self.rules.contains_key(candidate_ref)
166 };
167
168 if has_rule {
169 last_match_end = Some(j);
170 }
171
172 idx = j;
174 if let Ok((remaining, _)) = crate::lexer::skip_space_and_comments(&input[idx..]) {
175 let consumed = input[idx..].len() - remaining.len();
176 idx += consumed;
177 }
178
179 if let Some((_, next_ch)) = input[idx..].char_indices().next() {
181 if is_ident_char(next_ch) {
182 let prefix_candidate = input[start..idx].trim_end();
184 let prefix_candidate =
185 crate::lexer::collapse_line_continuations(prefix_candidate);
186 let prefix_candidate_ref = prefix_candidate.as_ref().trim_end();
187 let has_prefix = if self.case_insensitive {
189 self.prefixes
190 .iter()
191 .any(|p| p.eq_ignore_ascii_case(prefix_candidate_ref))
192 || self
193 .rules
194 .keys()
195 .any(|k| k.eq_ignore_ascii_case(prefix_candidate_ref))
196 } else {
197 self.prefixes.contains(prefix_candidate_ref)
198 || self.rules.contains_key(prefix_candidate_ref)
199 };
200 if has_prefix {
201 continue;
202 }
203 }
204 }
205
206 break;
207 }
208
209 let name_end = last_match_end
210 .ok_or_else(|| nom::Err::Error(nom::error::Error::new(input, ErrorKind::Tag)))?;
211
212 let raw_name = &input[start..name_end];
213 let normalized = crate::lexer::collapse_line_continuations(raw_name);
214 let normalized = if self.case_insensitive {
215 let lowered = normalized.as_ref().to_ascii_lowercase();
216 if lowered == normalized.as_ref() {
217 normalized
218 } else {
219 Cow::Owned(lowered)
220 }
221 } else {
222 normalized
223 };
224
225 let rest = &input[name_end..];
226
227 Ok((rest, normalized))
228 }
229}
230
231impl Default for DirectiveRegistry {
232 fn default() -> Self {
233 DirectiveRegistry::builder()
234 .register_generic("parallel")
235 .build()
236 }
237}
238
239pub struct DirectiveRegistryBuilder {
240 rules: HashMap<&'static str, DirectiveRule>,
241 prefixes: HashSet<String>,
242 default_rule: DirectiveRule,
243 case_insensitive: bool,
244}
245
246impl DirectiveRegistryBuilder {
247 pub fn new() -> Self {
248 Self {
249 rules: HashMap::new(),
250 prefixes: HashSet::new(),
251 default_rule: DirectiveRule::Generic,
252 case_insensitive: false,
253 }
254 }
255
256 pub fn register_generic(mut self, name: &'static str) -> Self {
257 self.insert_rule(name, DirectiveRule::Generic);
258 self
259 }
260
261 pub fn register_custom(mut self, name: &'static str, parser: DirectiveParserFn) -> Self {
262 self.insert_rule(name, DirectiveRule::Custom(parser));
263 self
264 }
265
266 pub fn with_default_rule(mut self, rule: DirectiveRule) -> Self {
267 self.default_rule = rule;
268 self
269 }
270
271 pub fn with_case_insensitive(mut self, enabled: bool) -> Self {
272 self.case_insensitive = enabled;
273 self
274 }
275
276 pub fn build(self) -> DirectiveRegistry {
277 DirectiveRegistry {
278 rules: self.rules,
279 prefixes: self.prefixes,
280 default_rule: self.default_rule,
281 case_insensitive: self.case_insensitive,
282 }
283 }
284
285 fn insert_rule(&mut self, name: &'static str, rule: DirectiveRule) {
286 self.rules.insert(name, rule);
287 self.register_prefixes(name);
288 }
289
290 fn register_prefixes(&mut self, name: &'static str) {
291 let segments = name.split_whitespace().collect::<Vec<_>>();
292 if segments.len() <= 1 {
293 return;
294 }
295
296 let mut current = String::new();
297 for segment in segments.iter().take(segments.len() - 1) {
298 if !current.is_empty() {
299 current.push(' ');
300 }
301 current.push_str(segment);
302 self.prefixes.insert(current.clone());
303 }
304 }
305}
306
307impl Default for DirectiveRegistryBuilder {
308 fn default() -> Self {
309 Self::new()
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::ClauseKind;
    use nom::bytes::complete::tag;

    #[test]
    fn parses_generic_directive_with_clauses() {
        let clauses_reg = ClauseRegistry::default();
        let directives = DirectiveRegistry::default();

        let (remaining, parsed) = directives
            .parse_with_name("parallel".into(), " private(x, y) nowait", &clauses_reg)
            .expect("parsing should succeed");

        assert_eq!(remaining, "");
        assert_eq!(parsed.name, "parallel");
        assert_eq!(parsed.clauses.len(), 2);

        let (first, second) = (&parsed.clauses[0], &parsed.clauses[1]);
        assert_eq!(first.name, "private");
        assert_eq!(first.kind, ClauseKind::Parenthesized("x, y".into()));
        assert_eq!(second.name, "nowait");
        assert_eq!(second.kind, ClauseKind::Bare);
    }

    #[test]
    fn parses_longest_matching_name() {
        let clauses_reg = ClauseRegistry::default();
        let directives = [
            "target teams",
            "target teams distribute",
            "target teams distribute parallel for",
        ]
        .iter()
        .copied()
        .fold(DirectiveRegistry::builder(), |builder, name| {
            builder.register_generic(name)
        })
        .build();

        let (remaining, parsed) = directives
            .parse(
                "target teams distribute parallel for private(a)",
                &clauses_reg,
            )
            .expect("parsing should succeed");

        assert_eq!(remaining, "");
        assert_eq!(parsed.name, "target teams distribute parallel for");
        assert_eq!(parsed.clauses.len(), 1);
        assert_eq!(parsed.clauses[0].name, "private");
    }

    /// Custom body parser: requires a literal `custom:` marker, then parses
    /// the remaining clause sequence.
    fn parse_prefixed_directive<'a>(
        name: Cow<'a, str>,
        input: &'a str,
        clause_registry: &ClauseRegistry,
    ) -> IResult<&'a str, Directive<'a>> {
        let (after_marker, _) = tag("custom:")(input)?;
        let (remaining, clauses) = clause_registry.parse_sequence(after_marker)?;
        Ok((remaining, Directive { name, clauses }))
    }

    #[test]
    fn supports_custom_directive_rule() {
        let clauses_reg = ClauseRegistry::default();
        let directives = DirectiveRegistry::builder()
            .register_custom("target", parse_prefixed_directive)
            .build();

        let (remaining, parsed) = directives
            .parse_with_name("target".into(), "custom: private(a)", &clauses_reg)
            .expect("parsing should succeed");

        assert_eq!(remaining, "");
        assert_eq!(parsed.name, "target");
        assert_eq!(parsed.clauses.len(), 1);

        let only = &parsed.clauses[0];
        assert_eq!(only.name, "private");
        assert_eq!(only.kind, ClauseKind::Parenthesized("a".into()));
    }

    #[test]
    fn directive_display_includes_all_clauses() {
        let private = Clause {
            name: "private".into(),
            kind: ClauseKind::Parenthesized("a, b".into()),
        };
        let nowait = Clause {
            name: "nowait".into(),
            kind: ClauseKind::Bare,
        };
        let directive = Directive {
            name: "parallel".into(),
            clauses: vec![private, nowait],
        };

        let expected = "#pragma omp parallel private(a, b) nowait";
        assert_eq!(directive.to_string(), expected);
        assert_eq!(directive.to_pragma_string(), expected);
    }

    #[test]
    fn directive_display_without_clauses() {
        let directive = Directive {
            name: "barrier".into(),
            clauses: Vec::new(),
        };

        let expected = "#pragma omp barrier";
        assert_eq!(directive.to_string(), expected);
        assert_eq!(directive.to_pragma_string(), expected);
    }
}