1use std::borrow::Cow;
18
19use nom::IResult;
20
21use nom::bytes::complete::{tag, take_while1};
26
/// Source-language flavour a directive is lexed for.
///
/// The Fortran variants distinguish the two source forms, which accept
/// different sets of directive sentinels.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Language {
    /// C-family sources, where directives are introduced by `#pragma`.
    C,
    /// Free-form Fortran sources.
    FortranFree,
    /// Fixed-form Fortran sources.
    FortranFixed,
}
37
38impl Default for Language {
39 fn default() -> Self {
40 Language::C
41 }
42}
43
/// Reports whether `c` may appear inside an identifier.
///
/// An identifier character is an underscore or any character for which
/// [`char::is_alphanumeric`] holds (this includes non-ASCII letters and digits).
pub fn is_identifier_char(c: char) -> bool {
    c == '_' || char::is_alphanumeric(c)
}
53
/// Returns `s` converted to lowercase as a freshly allocated `String`.
///
/// The conversion is the Unicode-aware [`str::to_lowercase`], so non-ASCII
/// letters are lowered as well.
pub fn normalize_fortran_identifier(s: &str) -> String {
    str::to_lowercase(s)
}
58
59pub fn lex_pragma(input: &str) -> IResult<&str, &str> {
67 tag("#pragma")(input)
68}
69
70pub fn lex_omp(input: &str) -> IResult<&str, &str> {
72 tag("omp")(input)
73}
74
75pub fn lex_fortran_free_sentinel(input: &str) -> IResult<&str, &str> {
82 let (after_space, _) = skip_space_and_comments(input)?;
84
85 let matches = after_space
87 .get(..5)
88 .map_or(false, |s| s.eq_ignore_ascii_case("!$omp"));
89
90 if matches {
91 Ok((&after_space[5..], &after_space[..5]))
92 } else {
93 Err(nom::Err::Error(nom::error::Error::new(
94 input,
95 nom::error::ErrorKind::Tag,
96 )))
97 }
98}
99
100pub fn lex_fortran_fixed_sentinel(input: &str) -> IResult<&str, &str> {
108 let (after_space, _) = skip_space_and_comments(input)?;
110
111 let first_5 = after_space.get(..5);
113 let matches = first_5.map_or(false, |s| {
114 s.eq_ignore_ascii_case("!$omp")
115 || s.eq_ignore_ascii_case("c$omp")
116 || s.eq_ignore_ascii_case("*$omp")
117 });
118
119 if matches {
120 Ok((&after_space[5..], &after_space[..5]))
121 } else {
122 Err(nom::Err::Error(nom::error::Error::new(
123 input,
124 nom::error::ErrorKind::Tag,
125 )))
126 }
127}
128
129fn lex_identifier(input: &str) -> IResult<&str, &str> {
137 take_while1(is_identifier_char)(input)
138}
139
140pub fn lex_directive(input: &str) -> IResult<&str, &str> {
142 lex_identifier(input)
143}
144
145pub fn lex_clause(input: &str) -> IResult<&str, &str> {
147 lex_identifier(input)
148}
149
150pub fn skip_space_and_comments(input: &str) -> IResult<&str, &str> {
157 let mut i = 0;
158 let bytes = input.as_bytes();
159 let len = bytes.len();
160
161 while i < len {
162 if let Some(next_idx) = skip_c_line_continuation(input, i) {
164 i = next_idx;
165 continue;
166 }
167
168 if let Some(next_idx) = skip_fortran_continuation(input, i) {
170 i = next_idx;
171 continue;
172 }
173
174 if bytes[i].is_ascii_whitespace() {
179 let ch = input[i..].chars().next().unwrap();
184 i += ch.len_utf8();
185 continue;
186 }
187
188 if i + 1 < len && &input[i..i + 2] == "/*" {
190 if let Some(end) = input[i + 2..].find("*/") {
191 i += 2 + end + 2;
192 continue;
193 } else {
194 i = len;
196 break;
197 }
198 }
199
200 if i + 1 < len && &input[i..i + 2] == "//" {
202 if let Some(end) = input[i + 2..].find('\n') {
203 i += 2 + end + 1;
204 } else {
205 i = len;
206 }
207 continue;
208 }
209
210 break;
211 }
212
213 Ok((&input[i..], &input[..0]))
215}
216
217pub fn skip_space1_and_comments(input: &str) -> IResult<&str, &str> {
219 let (rest, _) = skip_space_and_comments(input)?;
220
221 if rest.len() == input.len() {
225 Err(nom::Err::Error(nom::error::Error::new(
226 input,
227 nom::error::ErrorKind::Space,
228 )))
229 } else {
230 Ok((rest, &input[..0]))
231 }
232}
233
234pub fn lex_identifier_token(input: &str) -> IResult<&str, &str> {
236 lex_identifier(input)
237}
238
/// Reports whether `bytes` contains a byte that could start a line
/// continuation: a backslash or an ampersand.
///
/// A `true` result only means a continuation is possible, not that one is
/// actually present.
#[inline]
fn has_continuation_markers(bytes: &[u8]) -> bool {
    bytes.iter().any(|&byte| matches!(byte, b'\\' | b'&'))
}
254
255#[doc(hidden)]
260pub fn collapse_line_continuations<'a>(input: &'a str) -> Cow<'a, str> {
261 let bytes = input.as_bytes();
271 if !has_continuation_markers(bytes) {
272 return Cow::Borrowed(input);
273 }
274
275 let mut output = String::with_capacity(input.len());
276 let mut idx = 0;
277 let len = bytes.len();
278 let mut changed = false;
279
280 while idx < len {
281 if bytes[idx] == b'\\' {
282 let mut next = idx + 1;
283 while next < len && matches!(bytes[next], b' ' | b'\t') {
284 next += 1;
285 }
286 if next < len && (bytes[next] == b'\n' || bytes[next] == b'\r') {
287 changed = true;
288 if bytes[next] == b'\r' {
289 next += 1;
290 if next < len && bytes[next] == b'\n' {
291 next += 1;
292 }
293 } else {
294 next += 1;
295 }
296 while next < len && matches!(bytes[next], b' ' | b'\t') {
297 next += 1;
298 }
299 if !output.is_empty() && !output.ends_with(|c: char| c.is_whitespace()) {
302 output.push(' ');
303 }
304 idx = next;
305 continue;
306 }
307 } else if bytes[idx] == b'&' {
308 if let Some(next) = skip_fortran_continuation(input, idx) {
309 changed = true;
310 if !output.is_empty() && !output.ends_with(|c: char| c.is_whitespace()) {
312 output.push(' ');
313 }
314 idx = next;
315 continue;
316 }
317 }
318
319 let ch = input[idx..].chars().next().unwrap();
320 output.push(ch);
321 idx += ch.len_utf8();
322 }
323
324 if changed {
325 Cow::Owned(output)
326 } else {
327 Cow::Borrowed(input)
328 }
329}
330
/// Recognises a C line continuation starting at byte offset `idx`.
///
/// A continuation is a backslash, optional blanks (spaces or tabs), one line
/// terminator (`\n`, `\r` or `\r\n`) and the blanks that start the next line.
/// A backslash followed only by blanks up to the end of input also counts.
///
/// Returns the byte offset just past the continuation, or `None` when `idx`
/// is out of range or no continuation starts there.
fn skip_c_line_continuation(input: &str, idx: usize) -> Option<usize> {
    // Number of leading spaces/tabs in `bytes`.
    fn blank_run(bytes: &[u8]) -> usize {
        bytes
            .iter()
            .take_while(|&&b| matches!(b, b' ' | b'\t'))
            .count()
    }

    let bytes = input.as_bytes();
    // `get` covers the out-of-range index as well as the wrong byte.
    if bytes.get(idx) != Some(&b'\\') {
        return None;
    }

    let mut cursor = idx + 1;
    cursor += blank_run(&bytes[cursor..]);

    let terminator_len = match &bytes[cursor..] {
        // Backslash (plus blanks) at the very end of the input.
        [] => return Some(bytes.len()),
        [b'\r', b'\n', ..] => 2,
        [b'\n', ..] | [b'\r', ..] => 1,
        _ => return None,
    };
    cursor += terminator_len;
    cursor += blank_run(&bytes[cursor..]);

    Some(cursor)
}
366
367fn skip_fortran_continuation(input: &str, idx: usize) -> Option<usize> {
368 let bytes = input.as_bytes();
369 let len = bytes.len();
370 if idx >= len || bytes[idx] != b'&' {
371 return None;
372 }
373
374 let mut next = idx + 1;
375
376 while next < len {
377 match bytes[next] {
378 b' ' | b'\t' => next += 1,
379 b'!' => {
380 next += 1;
381 while next < len && bytes[next] != b'\n' && bytes[next] != b'\r' {
382 next += 1;
383 }
384 break;
385 }
386 b'\n' | b'\r' => break,
387 _ => return None,
388 }
389 }
390
391 if next >= len {
392 return Some(len);
393 }
394
395 if bytes[next] == b'\r' {
396 next += 1;
397 if next < len && bytes[next] == b'\n' {
398 next += 1;
399 }
400 } else if bytes[next] == b'\n' {
401 next += 1;
402 } else {
403 return None;
404 }
405
406 while next < len {
407 match bytes[next] {
408 b' ' | b'\t' => next += 1,
409 b'\r' | b'\n' => {
410 next += 1;
411 }
412 _ => break,
413 }
414 }
415
416 if let Some(len_sent) = match_fortran_sentinel(&input[next..]) {
417 next += len_sent;
418 while next < len && matches!(bytes[next], b' ' | b'\t') {
419 next += 1;
420 }
421 }
422
423 if next < len && bytes[next] == b'&' {
424 next += 1;
425 while next < len && matches!(bytes[next], b' ' | b'\t') {
426 next += 1;
427 }
428 }
429
430 Some(next)
431}
432
/// Checks whether `input` begins with a Fortran OpenMP sentinel.
///
/// The recognised sentinels are `!$omp`, `c$omp` and `*$omp`, compared
/// without regard to ASCII case. Returns the byte length of the sentinel when
/// one is found at the start of `input`, and `None` otherwise.
///
/// Never panics: input that is too short, or whose prefix of the required
/// length would end inside a multi-byte character, simply does not match.
fn match_fortran_sentinel(input: &str) -> Option<usize> {
    const SENTINELS: [&str; 3] = ["!$omp", "c$omp", "*$omp"];

    for candidate in SENTINELS {
        // Checked slicing: `input[..n]` would panic when `n` is not a char
        // boundary (for example "abcd\u{e9}"), whereas `get` returns `None`.
        let is_match = matches!(
            input.get(..candidate.len()),
            Some(prefix) if prefix.eq_ignore_ascii_case(candidate)
        );
        if is_match {
            return Some(candidate.len());
        }
    }

    None
}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `skip_space_and_comments` and hands back only the unconsumed tail.
    fn tail_after_skip(input: &str) -> &str {
        let (tail, _) = skip_space_and_comments(input).unwrap();
        tail
    }

    #[test]
    fn parses_pragma_keyword() {
        assert_eq!(
            lex_pragma("#pragma omp parallel"),
            Ok((" omp parallel", "#pragma"))
        );
    }

    #[test]
    fn parses_omp_keyword() {
        assert_eq!(lex_omp("omp parallel"), Ok((" parallel", "omp")));
    }

    #[test]
    fn parses_identifiers() {
        // (input, expected identifier, expected remainder)
        let cases = [
            ("parallel private", "parallel", " private"),
            ("private_data(x)", "private_data", "(x)"),
        ];

        for (input, name, rest) in cases {
            assert_eq!(lex_identifier(input), Ok((rest, name)), "input: {:?}", input);
        }
    }

    #[test]
    fn identifier_requires_alphanumeric() {
        assert!(lex_identifier("(invalid").is_err());
    }

    #[test]
    fn skips_whitespace() {
        assert_eq!(tail_after_skip(" hello"), "hello");
        assert_eq!(tail_after_skip("\t\n world"), "world");
    }

    #[test]
    fn skips_c_style_comments() {
        assert_eq!(tail_after_skip("/* comment */ code"), "code");
        assert_eq!(tail_after_skip("/* multi\nline\ncomment */ after"), "after");
    }

    #[test]
    fn skips_cpp_style_comments() {
        assert_eq!(tail_after_skip("// comment\ncode"), "code");
    }

    #[test]
    fn skips_mixed_whitespace_and_comments() {
        assert_eq!(
            tail_after_skip(" /* comment1 */ \n // comment2\n code"),
            "code"
        );
    }

    #[test]
    fn skip_space1_requires_whitespace() {
        assert!(skip_space1_and_comments("no_space").is_err());
        assert!(skip_space1_and_comments(" has_space").is_ok());
    }

    #[test]
    fn skip_space_handles_c_line_continuations() {
        assert_eq!(tail_after_skip("\\\n default(none)"), "default(none)");
    }

    #[test]
    fn skip_space_handles_fortran_continuations() {
        assert_eq!(tail_after_skip("&\n!$omp private(i, j)"), "private(i, j)");
    }

    #[test]
    fn collapse_line_continuations_removes_c_backslash() {
        let collapsed = collapse_line_continuations("a, \\\n b");
        assert_eq!(&*collapsed, "a, b");
    }

    #[test]
    fn collapse_line_continuations_removes_fortran_ampersand() {
        let collapsed = collapse_line_continuations("items( i, &\n!$omp& j )");
        assert_eq!(&*collapsed, "items( i, j )");
    }

    #[test]
    fn parses_fortran_free_sentinel() {
        // (input, expected sentinel as written, expected remainder)
        let cases = [
            ("!$OMP parallel", "!$OMP", " parallel"),
            ("!$omp PARALLEL", "!$omp", " PARALLEL"),
        ];

        for (input, sentinel, rest) in cases {
            assert_eq!(
                lex_fortran_free_sentinel(input),
                Ok((rest, sentinel)),
                "input: {:?}",
                input
            );
        }
    }

    #[test]
    fn parses_fortran_fixed_sentinel() {
        // (input, expected sentinel as written, expected remainder)
        let cases = [
            ("!$OMP parallel", "!$OMP", " parallel"),
            ("C$OMP parallel", "C$OMP", " parallel"),
            ("c$omp PARALLEL", "c$omp", " PARALLEL"),
        ];

        for (input, sentinel, rest) in cases {
            assert_eq!(
                lex_fortran_fixed_sentinel(input),
                Ok((rest, sentinel)),
                "input: {:?}",
                input
            );
        }
    }

    #[test]
    fn normalizes_fortran_identifiers() {
        let cases = [
            ("PARALLEL", "parallel"),
            ("Private", "private"),
            ("num_threads", "num_threads"),
        ];

        for (raw, expected) in cases {
            assert_eq!(normalize_fortran_identifier(raw), expected);
        }
    }

    #[test]
    fn optimized_single_pass_no_markers() {
        let inputs = [
            "parallel",
            "parallel for private(i)",
            "target teams distribute",
            "simd reduction(+:sum) private(i,j,k)",
        ];

        for input in inputs {
            // Without marker bytes the input must come back untouched and
            // without an allocation.
            match collapse_line_continuations(input) {
                Cow::Borrowed(text) => assert_eq!(text, input),
                Cow::Owned(text) => {
                    panic!("unexpected allocation for {:?}: {:?}", input, text)
                }
            }
        }
    }

    #[test]
    fn single_pass_with_markers() {
        let inputs = [
            "parallel \\\n num_threads(4)",
            "parallel do &\n!$omp private(i)",
        ];

        for input in inputs {
            assert!(
                matches!(collapse_line_continuations(input), Cow::Owned(_)),
                "expected an owned result for {:?}",
                input
            );
        }
    }
}