diff --git a/src/query/common.cpp b/src/query/common.cpp index 95b5de321..5d0115a96 100644 --- a/src/query/common.cpp +++ b/src/query/common.cpp @@ -38,10 +38,40 @@ std::string ParseStringLiteral(const std::string &s) { return converter.to_bytes(t); } else if (j - i >= kShortUnicodeLength + 1) { char16_t t = stoi(s.substr(i + 1, kShortUnicodeLength), 0, 16); - i += kShortUnicodeLength; - std::wstring_convert, char16_t> - converter; - return converter.to_bytes(t); + if (t >= 0xD800 && t <= 0xDBFF) { + // t is a high surrogate, the first half of a surrogate pair. Expect one more escaped UTF-16 code unit. + j = i + kShortUnicodeLength + 1; + if (j >= static_cast(s.size()) - 1 || s[j] != '\\') { + throw SemanticException("Invalid utf codepoint"); + } + ++j; + if (j >= static_cast(s.size()) - 1 || + (s[j] != 'u' && s[j] != 'U')) { + throw SemanticException("Invalid utf codepoint"); + } + ++j; + int k = j; + while (k < static_cast(s.size()) - 1 && + k < j + kShortUnicodeLength && isxdigit(s[k])) { + ++k; + } + if (k != j + kShortUnicodeLength) { + throw SemanticException("Invalid utf codepoint"); + } + char16_t surrogates[3] = {t, + static_cast(stoi( + s.substr(j, kShortUnicodeLength), 0, 16)), + 0}; + i += kShortUnicodeLength + 2 + kShortUnicodeLength; + std::wstring_convert, char16_t> + converter; + return converter.to_bytes(surrogates); + } else { + i += kShortUnicodeLength; + std::wstring_convert, char16_t> + converter; + return converter.to_bytes(t); + } } else { // This should never happen, except grammar changes and we don't notice // change in this production. 
@@ -88,7 +118,11 @@ std::string ParseStringLiteral(const std::string &s) { break; case 'U': case 'u': - unescaped += EncodeEscapedUnicodeCodepoint(s, i); + try { + unescaped += EncodeEscapedUnicodeCodepoint(s, i); + } catch (const std::range_error &) { + throw SemanticException("Invalid utf codepoint"); + } break; default: // This should never happen, except grammar changes and we don't diff --git a/tests/unit/stripped.cpp b/tests/unit/stripped.cpp index a914bcab0..a8837c551 100644 --- a/tests/unit/stripped.cpp +++ b/tests/unit/stripped.cpp @@ -131,6 +131,22 @@ TEST(QueryStripper, StringLiteral4) { EXPECT_EQ(stripped.query(), "return " + kStrippedStringToken); } +TEST(QueryStripper, HighSurrogateAlone) { + ASSERT_THROW(StrippedQuery("RETURN '\\ud83d'"), SemanticException); +} + +TEST(QueryStripper, LowSurrogateAlone) { + ASSERT_THROW(StrippedQuery("RETURN '\\udeeb'"), SemanticException); +} + +TEST(QueryStripper, Surrogates) { + StrippedQuery stripped("RETURN '\\ud83d\\udeeb'"); + EXPECT_EQ(stripped.literals().size(), 1); + EXPECT_EQ(stripped.literals().At(0).second.Value(), + u8"\U0001f6eb"); + EXPECT_EQ(stripped.query(), "return " + kStrippedStringToken); +} + TEST(QueryStripper, StringLiteralIllegalEscapedSequence) { EXPECT_THROW(StrippedQuery("RETURN 'so\\x'"), LexingException); EXPECT_THROW(StrippedQuery("RETURN 'so\\uabc'"), LexingException);