From c56621682ed7691d29cef9d2426a076da97cc017 Mon Sep 17 00:00:00 2001
From: Mislav Bradac <mislav.bradac@memgraph.io>
Date: Mon, 4 Sep 2017 16:03:17 +0200
Subject: [PATCH] Parse utf16 surrogate codepoints correctly

Reviewers: buda

Reviewed By: buda

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D745
---
 src/query/common.cpp    | 44 ++++++++++++++++++++++++++++++++++++-----
 tests/unit/stripped.cpp | 16 +++++++++++++++
 2 files changed, 55 insertions(+), 5 deletions(-)

diff --git a/src/query/common.cpp b/src/query/common.cpp
index 95b5de321..5d0115a96 100644
--- a/src/query/common.cpp
+++ b/src/query/common.cpp
@@ -38,10 +38,40 @@ std::string ParseStringLiteral(const std::string &s) {
       return converter.to_bytes(t);
     } else if (j - i >= kShortUnicodeLength + 1) {
       char16_t t = stoi(s.substr(i + 1, kShortUnicodeLength), 0, 16);
-      i += kShortUnicodeLength;
-      std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t>
-          converter;
-      return converter.to_bytes(t);
+      if (t >= 0xD800 && t <= 0xDBFF) {
+        // t is high surrogate pair. Expect one more utf16 codepoint.
+        j = i + kShortUnicodeLength + 1;
+        if (j >= static_cast<int>(s.size()) - 1 || s[j] != '\\') {
+          throw SemanticException("Invalid utf codepoint");
+        }
+        ++j;
+        if (j >= static_cast<int>(s.size()) - 1 ||
+            (s[j] != 'u' && s[j] != 'U')) {
+          throw SemanticException("Invalid utf codepoint");
+        }
+        ++j;
+        int k = j;
+        while (k < static_cast<int>(s.size()) - 1 &&
+               k < j + kShortUnicodeLength && isxdigit(s[k])) {
+          ++k;
+        }
+        if (k != j + kShortUnicodeLength) {
+          throw SemanticException("Invalid utf codepoint");
+        }
+        char16_t surrogates[3] = {t,
+                                  static_cast<char16_t>(stoi(
+                                      s.substr(j, kShortUnicodeLength), 0, 16)),
+                                  0};
+        i += kShortUnicodeLength + 2 + kShortUnicodeLength;
+        std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t>
+            converter;
+        return converter.to_bytes(surrogates);
+      } else {
+        i += kShortUnicodeLength;
+        std::wstring_convert<std::codecvt_utf8_utf16<char16_t>, char16_t>
+            converter;
+        return converter.to_bytes(t);
+      }
     } else {
       // This should never happen, except grammar changes and we don't notice
       // change in this production.
@@ -88,7 +118,11 @@ std::string ParseStringLiteral(const std::string &s) {
           break;
         case 'U':
         case 'u':
-          unescaped += EncodeEscapedUnicodeCodepoint(s, i);
+          try {
+            unescaped += EncodeEscapedUnicodeCodepoint(s, i);
+          } catch (const std::range_error &) {
+            throw SemanticException("Invalid utf codepoint");
+          }
           break;
         default:
           // This should never happen, except grammar changes and we don't
diff --git a/tests/unit/stripped.cpp b/tests/unit/stripped.cpp
index a914bcab0..a8837c551 100644
--- a/tests/unit/stripped.cpp
+++ b/tests/unit/stripped.cpp
@@ -131,6 +131,22 @@ TEST(QueryStripper, StringLiteral4) {
   EXPECT_EQ(stripped.query(), "return " + kStrippedStringToken);
 }
 
+TEST(QueryStripper, HighSurrogateAlone) {
+  ASSERT_THROW(StrippedQuery("RETURN '\\udeeb'"), SemanticException);
+}
+
+TEST(QueryStripper, LowSurrogateAlone) {
+  ASSERT_THROW(StrippedQuery("RETURN '\\ud83d'"), SemanticException);
+}
+
+TEST(QueryStripper, Surrogates) {
+  StrippedQuery stripped("RETURN '\\ud83d\\udeeb'");
+  EXPECT_EQ(stripped.literals().size(), 1);
+  EXPECT_EQ(stripped.literals().At(0).second.Value<std::string>(),
+            u8"\U0001f6eb");
+  EXPECT_EQ(stripped.query(), "return " + kStrippedStringToken);
+}
+
 TEST(QueryStripper, StringLiteralIllegalEscapedSequence) {
   EXPECT_THROW(StrippedQuery("RETURN 'so\\x'"), LexingException);
   EXPECT_THROW(StrippedQuery("RETURN 'so\\uabc'"), LexingException);