diff options
Diffstat (limited to 'clang/lib/Sema')
-rw-r--r-- | clang/lib/Sema/ScopeInfo.cpp | 3 | ||||
-rw-r--r-- | clang/lib/Sema/SemaCoroutine.cpp | 533 | ||||
-rw-r--r-- | clang/lib/Sema/SemaDecl.cpp | 2 | ||||
-rw-r--r-- | clang/lib/Sema/SemaExceptionSpec.cpp | 1 | ||||
-rw-r--r-- | clang/lib/Sema/TreeTransform.h | 137 |
5 files changed, 456 insertions, 220 deletions
diff --git a/clang/lib/Sema/ScopeInfo.cpp b/clang/lib/Sema/ScopeInfo.cpp index 58d44bacea9..8050889d71a 100644 --- a/clang/lib/Sema/ScopeInfo.cpp +++ b/clang/lib/Sema/ScopeInfo.cpp @@ -43,6 +43,9 @@ void FunctionScopeInfo::Clear() { SwitchStack.clear(); Returns.clear(); CoroutinePromise = nullptr; + NeedsCoroutineSuspends = true; + CoroutineSuspends.first = nullptr; + CoroutineSuspends.second = nullptr; CoroutineStmts.clear(); ErrorTrap.reset(); PossiblyUnreachableDiags.clear(); diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp index 31bef09ee9a..9fec855bab2 100644 --- a/clang/lib/Sema/SemaCoroutine.cpp +++ b/clang/lib/Sema/SemaCoroutine.cpp @@ -21,6 +21,16 @@ using namespace clang; using namespace sema; +static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD, + SourceLocation Loc) { + DeclarationName DN = S.PP.getIdentifierInfo(Name); + LookupResult LR(S, DN, Loc, Sema::LookupMemberName); + // Suppress diagnostics when a private member is selected. The same warnings + // will be produced again when building the call. + LR.suppressDiagnostics(); + return S.LookupQualifiedName(LR, RD); +} + /// Look up the std::coroutine_traits<...>::promise_type for the given /// function type. static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType, @@ -167,42 +177,48 @@ static bool isValidCoroutineContext(Sema &S, SourceLocation Loc, return !Diagnosed; } -/// Check that this is a context in which a coroutine suspension can appear. -static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, - StringRef Keyword) { - if (!isValidCoroutineContext(S, Loc, Keyword)) - return nullptr; - - assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope"); - auto *FD = cast<FunctionDecl>(S.CurContext); - auto *ScopeInfo = S.getCurFunction(); - assert(ScopeInfo && "missing function scope for function"); +static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S, + SourceLocation Loc) { + DeclarationName OpName = + SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait); + LookupResult Operators(SemaRef, OpName, SourceLocation(), + Sema::LookupOperatorName); + SemaRef.LookupName(Operators, S); + + assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous"); + const auto &Functions = Operators.asUnresolvedSet(); + bool IsOverloaded = + Functions.size() > 1 || + (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin())); + Expr *CoawaitOp = UnresolvedLookupExpr::Create( + SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(), + DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded, + Functions.begin(), Functions.end()); + assert(CoawaitOp); + return CoawaitOp; +} - // If we don't have a promise variable, build one now. - if (!ScopeInfo->CoroutinePromise) { - QualType T = FD->getType()->isDependentType() - ? S.Context.DependentTy - : lookupPromiseType( - S, FD->getType()->castAs<FunctionProtoType>(), - Loc, FD->getLocation()); - if (T.isNull()) - return nullptr; - - // Create and default-initialize the promise. - ScopeInfo->CoroutinePromise = - VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(), - &S.PP.getIdentifierTable().get("__promise"), T, - S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None); - S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise); - if (!ScopeInfo->CoroutinePromise->isInvalidDecl()) - S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise); - } +/// Build a call to 'operator co_await' if there is a suitable operator for +/// the given expression. +static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc, + Expr *E, + UnresolvedLookupExpr *Lookup) { + UnresolvedSet<16> Functions; + Functions.append(Lookup->decls_begin(), Lookup->decls_end()); + return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); +} - return ScopeInfo; +static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, + SourceLocation Loc, Expr *E) { + ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc); + if (R.isInvalid()) + return ExprError(); + return buildOperatorCoawaitCall(SemaRef, Loc, E, + cast<UnresolvedLookupExpr>(R.get())); } static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id, - MutableArrayRef<Expr *> CallArgs) { + MultiExprArg CallArgs) { StringRef Name = S.Context.BuiltinInfo.getName(Id); LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName); S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true); @@ -221,15 +237,6 @@ static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id, return Call.get(); } -/// Build a call to 'operator co_await' if there is a suitable operator for -/// the given expression. -static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, - SourceLocation Loc, Expr *E) { - UnresolvedSet<16> Functions; - SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(), - Functions); - return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); -} struct ReadySuspendResumeResult { bool IsInvalid; @@ -237,8 +244,7 @@ struct ReadySuspendResumeResult { }; static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc, - StringRef Name, - MutableArrayRef<Expr *> Args) { + StringRef Name, MultiExprArg Args) { DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc); // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&. @@ -276,25 +282,174 @@ static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc, return Calls; } +static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise, + SourceLocation Loc, StringRef Name, + MultiExprArg Args) { + + // Form a reference to the promise. + ExprResult PromiseRef = S.BuildDeclRefExpr( + Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); + if (PromiseRef.isInvalid()) + return ExprError(); + + // Call 'yield_value', passing in E. + return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); +} + +VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) { + assert(isa<FunctionDecl>(CurContext) && "not in a function scope"); + auto *FD = cast<FunctionDecl>(CurContext); + + QualType T = + FD->getType()->isDependentType() + ? Context.DependentTy + : lookupPromiseType(*this, FD->getType()->castAs<FunctionProtoType>(), + Loc, FD->getLocation()); + if (T.isNull()) + return nullptr; + + auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(), + &PP.getIdentifierTable().get("__promise"), T, + Context.getTrivialTypeSourceInfo(T, Loc), SC_None); + CheckVariableDeclarationType(VD); + if (VD->isInvalidDecl()) + return nullptr; + ActOnUninitializedDecl(VD); + assert(!VD->isInvalidDecl()); + return VD; +} + +/// Check that this is a context in which a coroutine suspension can appear. +static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, + StringRef Keyword) { + if (!isValidCoroutineContext(S, Loc, Keyword)) + return nullptr; + + assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope"); + auto *FD = cast<FunctionDecl>(S.CurContext); + + auto *ScopeInfo = S.getCurFunction(); + assert(ScopeInfo && "missing function scope for function"); + + if (ScopeInfo->CoroutinePromise) + return ScopeInfo; + + ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc); + if (!ScopeInfo->CoroutinePromise) + return nullptr; + + return ScopeInfo; +} + +static bool actOnCoroutineBodyStart(Sema &S, Scope *SC, SourceLocation KWLoc, + StringRef Keyword) { + if (!checkCoroutineContext(S, KWLoc, Keyword)) + return false; + auto *ScopeInfo = S.getCurFunction(); + assert(ScopeInfo->CoroutinePromise); + + // If we have existing coroutine statements then we have already built + // the initial and final suspend points. + if (!ScopeInfo->NeedsCoroutineSuspends) + return true; + + ScopeInfo->setNeedsCoroutineSuspends(false); + + auto *Fn = cast<FunctionDecl>(S.CurContext); + SourceLocation Loc = Fn->getLocation(); + // Build the initial suspend point + auto buildSuspends = [&](StringRef Name) mutable -> StmtResult { + ExprResult Suspend = + buildPromiseCall(S, ScopeInfo->CoroutinePromise, Loc, Name, None); + if (Suspend.isInvalid()) + return StmtError(); + Suspend = buildOperatorCoawaitCall(S, SC, Loc, Suspend.get()); + if (Suspend.isInvalid()) + return StmtError(); + Suspend = S.BuildResolvedCoawaitExpr(Loc, Suspend.get(), + /*IsImplicit*/ true); + Suspend = S.ActOnFinishFullExpr(Suspend.get()); + if (Suspend.isInvalid()) { + S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) + << ((Name == "initial_suspend") ? 0 : 1); + S.Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword; + return StmtError(); + } + return cast<Stmt>(Suspend.get()); + }; + + StmtResult InitSuspend = buildSuspends("initial_suspend"); + if (InitSuspend.isInvalid()) + return true; + + StmtResult FinalSuspend = buildSuspends("final_suspend"); + if (FinalSuspend.isInvalid()) + return true; + + ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get()); + + return true; +} + ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); - if (!Coroutine) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_await")) { CorrectDelayedTyposInExpr(E); return ExprError(); } + if (E->getType()->isPlaceholderType()) { ExprResult R = CheckPlaceholderExpr(E); if (R.isInvalid()) return ExprError(); E = R.get(); } + ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc); + if (Lookup.isInvalid()) + return ExprError(); + return BuildUnresolvedCoawaitExpr(Loc, E, + cast<UnresolvedLookupExpr>(Lookup.get())); +} + +ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E, + UnresolvedLookupExpr *Lookup) { + auto *FSI = checkCoroutineContext(*this, Loc, "co_await"); + if (!FSI) + return ExprError(); - ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E); + if (E->getType()->isPlaceholderType()) { + ExprResult R = CheckPlaceholderExpr(E); + if (R.isInvalid()) + return ExprError(); + E = R.get(); + } + + auto *Promise = FSI->CoroutinePromise; + if (Promise->getType()->isDependentType()) { + Expr *Res = + new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup); + FSI->CoroutineStmts.push_back(Res); + return Res; + } + + auto *RD = Promise->getType()->getAsCXXRecordDecl(); + if (lookupMember(*this, "await_transform", RD, Loc)) { + ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E); + if (R.isInvalid()) { + Diag(Loc, + diag::note_coroutine_promise_implicit_await_transform_required_here) + << E->getSourceRange(); + return ExprError(); + } + E = R.get(); + } + ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup); if (Awaitable.isInvalid()) return ExprError(); - return BuildCoawaitExpr(Loc, Awaitable.get()); + return BuildResolvedCoawaitExpr(Loc, Awaitable.get()); } -ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) { + +ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E, + bool IsImplicit) { auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); if (!Coroutine) return ExprError(); @@ -306,8 +461,10 @@ ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) { } if (E->getType()->isDependentType()) { - Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E); - Coroutine->CoroutineStmts.push_back(Res); + Expr *Res = new (Context) + CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit); + if (!IsImplicit) + Coroutine->CoroutineStmts.push_back(Res); return Res; } @@ -322,37 +479,21 @@ ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) { return ExprError(); Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1], - RSS.Results[2]); - Coroutine->CoroutineStmts.push_back(Res); + RSS.Results[2], IsImplicit); + if (!IsImplicit) + Coroutine->CoroutineStmts.push_back(Res); return Res; } -static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine, - SourceLocation Loc, StringRef Name, - MutableArrayRef<Expr *> Args) { - assert(Coroutine->CoroutinePromise && "no promise for coroutine"); - - // Form a reference to the promise. - auto *Promise = Coroutine->CoroutinePromise; - ExprResult PromiseRef = S.BuildDeclRefExpr( - Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); - if (PromiseRef.isInvalid()) - return ExprError(); - - // Call 'yield_value', passing in E. - return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); -} - ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield"); - if (!Coroutine) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_yield")) { CorrectDelayedTyposInExpr(E); return ExprError(); } // Build yield_value call. - ExprResult Awaitable = - buildPromiseCall(*this, Coroutine, Loc, "yield_value", E); + ExprResult Awaitable = buildPromiseCall( + *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E); if (Awaitable.isInvalid()) return ExprError(); @@ -396,18 +537,18 @@ ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) { return Res; } -StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return"); - if (!Coroutine) { +StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_return")) { CorrectDelayedTyposInExpr(E); return StmtError(); } return BuildCoreturnStmt(Loc, E); } -StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return"); - if (!Coroutine) +StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E, + bool IsImplicit) { + auto *FSI = checkCoroutineContext(*this, Loc, "co_return"); + if (!FSI) return StmtError(); if (E && E->getType()->isPlaceholderType() && @@ -420,20 +561,22 @@ StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) { // FIXME: If the operand is a reference to a variable that's about to go out // of scope, we should treat the operand as an xvalue for this overload // resolution. + VarDecl *Promise = FSI->CoroutinePromise; ExprResult PC; if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) { - PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E); + PC = buildPromiseCall(*this, Promise, Loc, "return_value", E); } else { E = MakeFullDiscardedValueExpr(E).get(); - PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None); + PC = buildPromiseCall(*this, Promise, Loc, "return_void", None); } if (PC.isInvalid()) return StmtError(); Expr *PCE = ActOnFinishFullExpr(PC.get()).get(); - Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE); - Coroutine->CoroutineStmts.push_back(Res); + Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit); + if (!IsImplicit) + FSI->CoroutineStmts.push_back(Res); return Res; } @@ -490,88 +633,6 @@ static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc, return OperatorDelete; } -// Builds allocation and deallocation for the coroutine. Returns false on -// failure. -static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc, - FunctionScopeInfo *Fn, - Expr *&Allocation, - Expr *&Deallocation) { - TypeSourceInfo *TInfo = Fn->CoroutinePromise->getTypeSourceInfo(); - QualType PromiseType = TInfo->getType(); - if (PromiseType->isDependentType()) - return true; - - if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type)) - return false; - - // FIXME: Add support for get_return_object_on_allocation failure. - // FIXME: Add support for stateful allocators. - - FunctionDecl *OperatorNew = nullptr; - FunctionDecl *OperatorDelete = nullptr; - FunctionDecl *UnusedResult = nullptr; - bool PassAlignment = false; - - S.FindAllocationFunctions(Loc, SourceRange(), - /*UseGlobal*/ false, PromiseType, - /*isArray*/ false, PassAlignment, - /*PlacementArgs*/ None, OperatorNew, UnusedResult); - - OperatorDelete = findDeleteForPromise(S, Loc, PromiseType); - - if (!OperatorDelete || !OperatorNew) - return false; - - Expr *FramePtr = - buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {}); - - Expr *FrameSize = - buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {}); - - // Make new call. - - ExprResult NewRef = - S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc); - if (NewRef.isInvalid()) - return false; - - ExprResult NewExpr = - S.ActOnCallExpr(S.getCurScope(), NewRef.get(), Loc, FrameSize, Loc); - if (NewExpr.isInvalid()) - return false; - - Allocation = NewExpr.get(); - - // Make delete call. - - QualType OpDeleteQualType = OperatorDelete->getType(); - - ExprResult DeleteRef = - S.BuildDeclRefExpr(OperatorDelete, OpDeleteQualType, VK_LValue, Loc); - if (DeleteRef.isInvalid()) - return false; - - Expr *CoroFree = - buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr}); - - SmallVector<Expr *, 2> DeleteArgs{CoroFree}; - - // Check if we need to pass the size. - const auto *OpDeleteType = - OpDeleteQualType.getTypePtr()->getAs<FunctionProtoType>(); - if (OpDeleteType->getNumParams() > 1) - DeleteArgs.push_back(FrameSize); - - ExprResult DeleteExpr = - S.ActOnCallExpr(S.getCurScope(), DeleteRef.get(), Loc, DeleteArgs, Loc); - if (DeleteExpr.isInvalid()) - return false; - - Deallocation = DeleteExpr.get(); - - return true; -} - namespace { class SubStmtBuilder : public CoroutineBodyStmt::CtorArgs { Sema &S; @@ -595,17 +656,16 @@ public: PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl(); assert(PromiseRecordDecl && "Type should have already been checked"); } - this->IsValid = makePromiseStmt() && makeInitialSuspend() && - makeFinalSuspend() && makeOnException() && - makeOnFallthrough() && makeNewAndDeleteExpr() && - makeReturnObject() && makeParamMoves(); + this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend() && + makeOnException() && makeOnFallthrough() && + makeNewAndDeleteExpr() && makeReturnObject() && + makeParamMoves(); } bool isInvalid() const { return !this->IsValid; } bool makePromiseStmt(); - bool makeInitialSuspend(); - bool makeFinalSuspend(); + bool makeInitialAndFinalSuspend(); bool makeNewAndDeleteExpr(); bool makeOnFallthrough(); bool makeOnException(); @@ -616,7 +676,7 @@ public: void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { FunctionScopeInfo *Fn = getCurFunction(); - assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine"); + assert(Fn && Fn->CoroutinePromise && "not a coroutine"); // Coroutines [stmt.return]p1: // A return statement shall not appear in a coroutine. @@ -624,8 +684,8 @@ void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); auto *First = Fn->CoroutineStmts[0]; Diag(First->getLocStart(), diag::note_declared_coroutine_here) - << (isa<CoawaitExpr>(First) ? 0 : - isa<CoyieldExpr>(First) ? 1 : 2); + << (isa<CoawaitExpr>(First) ? "co_await" : + isa<CoyieldExpr>(First) ? "co_yield" : "co_return"); } SubStmtBuilder Builder(*this, *FD, *Fn, Body); if (Builder.isInvalid()) @@ -647,40 +707,88 @@ bool SubStmtBuilder::makePromiseStmt() { return true; } -bool SubStmtBuilder::makeInitialSuspend() { - // Form and check implicit 'co_await p.initial_suspend();' statement. - ExprResult InitialSuspend = - buildPromiseCall(S, &Fn, Loc, "initial_suspend", None); - // FIXME: Support operator co_await here. - if (!InitialSuspend.isInvalid()) - InitialSuspend = S.BuildCoawaitExpr(Loc, InitialSuspend.get()); - InitialSuspend = S.ActOnFinishFullExpr(InitialSuspend.get()); - if (InitialSuspend.isInvalid()) +bool SubStmtBuilder::makeInitialAndFinalSuspend() { + if (Fn.hasInvalidCoroutineSuspends()) return false; - - this->InitialSuspend = InitialSuspend.get(); + this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first); + this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second); return true; } -bool SubStmtBuilder::makeFinalSuspend() { - // Form and check implicit 'co_await p.final_suspend();' statement. - ExprResult FinalSuspend = - buildPromiseCall(S, &Fn, Loc, "final_suspend", None); - // FIXME: Support operator co_await here. - if (!FinalSuspend.isInvalid()) - FinalSuspend = S.BuildCoawaitExpr(Loc, FinalSuspend.get()); - FinalSuspend = S.ActOnFinishFullExpr(FinalSuspend.get()); - if (FinalSuspend.isInvalid()) +bool SubStmtBuilder::makeNewAndDeleteExpr() { + // Form and check allocation and deallocation calls. + QualType PromiseType = Fn.CoroutinePromise->getType(); + if (PromiseType->isDependentType()) + return true; + + if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type)) return false; - this->FinalSuspend = FinalSuspend.get(); - return true; -} + // FIXME: Add support for get_return_object_on_allocation failure. + // FIXME: Add support for stateful allocators. -bool SubStmtBuilder::makeNewAndDeleteExpr() { - // Form and check allocation and deallocation calls. - return buildAllocationAndDeallocation(S, Loc, &Fn, this->Allocate, - this->Deallocate); + FunctionDecl *OperatorNew = nullptr; + FunctionDecl *OperatorDelete = nullptr; + FunctionDecl *UnusedResult = nullptr; + bool PassAlignment = false; + + S.FindAllocationFunctions(Loc, SourceRange(), + /*UseGlobal*/ false, PromiseType, + /*isArray*/ false, PassAlignment, + /*PlacementArgs*/ None, OperatorNew, UnusedResult); + + OperatorDelete = findDeleteForPromise(S, Loc, PromiseType); + + if (!OperatorDelete || !OperatorNew) + return false; + + Expr *FramePtr = + buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {}); + + Expr *FrameSize = + buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {}); + + // Make new call. + + ExprResult NewRef = + S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc); + if (NewRef.isInvalid()) + return false; + + ExprResult NewExpr = + S.ActOnCallExpr(S.getCurScope(), NewRef.get(), Loc, FrameSize, Loc); + if (NewExpr.isInvalid()) + return false; + + // Make delete call. + + QualType OpDeleteQualType = OperatorDelete->getType(); + + ExprResult DeleteRef = + S.BuildDeclRefExpr(OperatorDelete, OpDeleteQualType, VK_LValue, Loc); + if (DeleteRef.isInvalid()) + return false; + + Expr *CoroFree = + buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr}); + + SmallVector<Expr *, 2> DeleteArgs{CoroFree}; + + // Check if we need to pass the size. + const auto *OpDeleteType = + OpDeleteQualType.getTypePtr()->getAs<FunctionProtoType>(); + if (OpDeleteType->getNumParams() > 1) + DeleteArgs.push_back(FrameSize); + + ExprResult DeleteExpr = + S.ActOnCallExpr(S.getCurScope(), DeleteRef.get(), Loc, DeleteArgs, Loc); + if (DeleteExpr.isInvalid()) + return false; + + this->Allocate = NewExpr.get(); + this->Deallocate = DeleteExpr.get(); + + return true; } bool SubStmtBuilder::makeOnFallthrough() { @@ -690,13 +798,8 @@ bool SubStmtBuilder::makeOnFallthrough() { // [dcl.fct.def.coroutine]/4 // The unqualified-ids 'return_void' and 'return_value' are looked up in // the scope of class P. If both are found, the program is ill-formed. - DeclarationName RVoidDN = S.PP.getIdentifierInfo("return_void"); - LookupResult RVoidResult(S, RVoidDN, Loc, Sema::LookupMemberName); - const bool HasRVoid = S.LookupQualifiedName(RVoidResult, PromiseRecordDecl); - - DeclarationName RValueDN = S.PP.getIdentifierInfo("return_value"); - LookupResult RValueResult(S, RValueDN, Loc, Sema::LookupMemberName); - const bool HasRValue = S.LookupQualifiedName(RValueResult, PromiseRecordDecl); + const bool HasRVoid = lookupMember(S, "return_void", PromiseRecordDecl, Loc); + const bool HasRValue = lookupMember(S, "return_value", PromiseRecordDecl, Loc); StmtResult Fallthrough; if (HasRVoid && HasRValue) { @@ -708,7 +811,8 @@ bool SubStmtBuilder::makeOnFallthrough() { // If the unqualified-id return_void is found, flowing off the end of a // coroutine is equivalent to a co_return with no operand. Otherwise, // flowing off the end of a coroutine results in undefined behavior. - Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr); + Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr, + /*IsImplicit*/false); Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get()); if (Fallthrough.isInvalid()) return false; @@ -736,15 +840,13 @@ bool SubStmtBuilder::makeOnException() { // [dcl.fct.def.coroutine]/3 // The unqualified-id set_exception is found in the scope of P by class // member access lookup (3.4.5). - DeclarationName SetExDN = S.PP.getIdentifierInfo("set_exception"); - LookupResult SetExResult(S, SetExDN, Loc, Sema::LookupMemberName); - if (S.LookupQualifiedName(SetExResult, PromiseRecordDecl)) { + if (lookupMember(S, "set_exception", PromiseRecordDecl, Loc)) { // Form the call 'p.set_exception(std::current_exception())' SetException = buildStdCurrentExceptionCall(S, Loc); if (SetException.isInvalid()) return false; Expr *E = SetException.get(); - SetException = buildPromiseCall(S, &Fn, Loc, "set_exception", E); + SetException = buildPromiseCall(S, Fn.CoroutinePromise, Loc, "set_exception", E); SetException = S.ActOnFinishFullExpr(SetException.get(), Loc); if (SetException.isInvalid()) return false; @@ -759,7 +861,7 @@ bool SubStmtBuilder::makeReturnObject() { // Build implicit 'p.get_return_object()' expression and form initialization // of return type from it. ExprResult ReturnObject = - buildPromiseCall(S, &Fn, Loc, "get_return_object", None); + buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None); if (ReturnObject.isInvalid()) return false; QualType RetType = FD.getReturnType(); @@ -783,3 +885,10 @@ bool SubStmtBuilder::makeParamMoves() { // FIXME: Perform move-initialization of parameters into frame-local copies. return true; } + +StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) { + CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args); + if (!Res) + return StmtError(); + return Res; +} diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index 40ab1d29ae8..d7d71221b5d 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -11989,7 +11989,7 @@ Decl *Sema::ActOnFinishFunctionBody(Decl *dcl, Stmt *Body, sema::AnalysisBasedWarnings::Policy WP = AnalysisWarnings.getDefaultPolicy(); sema::AnalysisBasedWarnings::Policy *ActivePolicy = nullptr; - if (getLangOpts().CoroutinesTS && !getCurFunction()->CoroutineStmts.empty()) + if (getLangOpts().CoroutinesTS && getCurFunction()->CoroutinePromise) CheckCompletedCoroutineBody(FD, Body); if (FD) { diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp index 2ac2aca6f66..deb6cbb53af 100644 --- a/clang/lib/Sema/SemaExceptionSpec.cpp +++ b/clang/lib/Sema/SemaExceptionSpec.cpp @@ -1182,6 +1182,7 @@ CanThrowResult Sema::canThrow(const Expr *E) { case Expr::ArraySubscriptExprClass: case Expr::OMPArraySectionExprClass: case Expr::BinaryOperatorClass: + case Expr::DependentCoawaitExprClass: case Expr::CompoundAssignOperatorClass: case Expr::CStyleCastExprClass: case Expr::CXXStaticCastExprClass: diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 8a63d354746..4e22762eb19 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -1362,16 +1362,28 @@ public: /// /// By default, performs semantic analysis to build the new statement. /// Subclasses may override this routine to provide different behavior. - StmtResult RebuildCoreturnStmt(SourceLocation CoreturnLoc, Expr *Result) { - return getSema().BuildCoreturnStmt(CoreturnLoc, Result); + StmtResult RebuildCoreturnStmt(SourceLocation CoreturnLoc, Expr *Result, + bool IsImplicit) { + return getSema().BuildCoreturnStmt(CoreturnLoc, Result, IsImplicit); } /// \brief Build a new co_await expression. /// /// By default, performs semantic analysis to build the new expression. /// Subclasses may override this routine to provide different behavior. - ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Result) { - return getSema().BuildCoawaitExpr(CoawaitLoc, Result); + ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Result, + bool IsImplicit) { + return getSema().BuildResolvedCoawaitExpr(CoawaitLoc, Result, IsImplicit); + } + + /// \brief Build a new co_await expression. + /// + /// By default, performs semantic analysis to build the new expression. + /// Subclasses may override this routine to provide different behavior. + ExprResult RebuildDependentCoawaitExpr(SourceLocation CoawaitLoc, + Expr *Result, + UnresolvedLookupExpr *Lookup) { + return getSema().BuildUnresolvedCoawaitExpr(CoawaitLoc, Result, Lookup); } /// \brief Build a new co_yield expression. @@ -1382,6 +1394,10 @@ public: return getSema().BuildCoyieldExpr(CoyieldLoc, Result); } + StmtResult RebuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) { + return getSema().BuildCoroutineBodyStmt(Args); + } + /// \brief Build a new Objective-C \@try statement. /// /// By default, performs semantic analysis to build the new statement. @@ -6833,7 +6849,91 @@ StmtResult TreeTransform<Derived>::TransformCoroutineBodyStmt(CoroutineBodyStmt *S) { // The coroutine body should be re-formed by the caller if necessary. // FIXME: The coroutine body is always rebuilt by ActOnFinishFunctionBody - return getDerived().TransformStmt(S->getBody()); + CoroutineBodyStmt::CtorArgs BodyArgs; + + auto *ScopeInfo = SemaRef.getCurFunction(); + auto *FD = cast<FunctionDecl>(SemaRef.CurContext); + assert(ScopeInfo && !ScopeInfo->CoroutinePromise && + ScopeInfo->NeedsCoroutineSuspends && + ScopeInfo->CoroutineSuspends.first == nullptr && + ScopeInfo->CoroutineSuspends.second == nullptr && + ScopeInfo->CoroutineStmts.empty() && "expected clean scope info"); + + // Set that we have (possibly-invalid) suspend points before we do anything + // that may fail. + ScopeInfo->setNeedsCoroutineSuspends(false); + + // The new CoroutinePromise object needs to be built and put into the current + // FunctionScopeInfo before any transformations or rebuilding occurs. + auto *Promise = S->getPromiseDecl(); + auto *NewPromise = SemaRef.buildCoroutinePromise(FD->getLocation()); + if (!NewPromise) + return StmtError(); + getDerived().transformedLocalDecl(Promise, NewPromise); + ScopeInfo->CoroutinePromise = NewPromise; + StmtResult PromiseStmt = SemaRef.ActOnDeclStmt( + SemaRef.ConvertDeclToDeclGroup(NewPromise), + FD->getLocation(), FD->getLocation()); + assert(!PromiseStmt.isInvalid()); + BodyArgs.Promise = PromiseStmt.get(); + + // Transform the implicit coroutine statements we built during the initial + // parse. + StmtResult InitSuspend = getDerived().TransformStmt(S->getInitSuspendStmt()); + if (InitSuspend.isInvalid()) + return StmtError(); + StmtResult FinalSuspend = + getDerived().TransformStmt(S->getFinalSuspendStmt()); + if (FinalSuspend.isInvalid()) + return StmtError(); + ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get()); + assert(isa<Expr>(InitSuspend.get()) && isa<Expr>(FinalSuspend.get())); + BodyArgs.InitialSuspend = cast<Expr>(InitSuspend.get()); + BodyArgs.FinalSuspend = cast<Expr>(FinalSuspend.get()); + + StmtResult BodyRes = getDerived().TransformStmt(S->getBody()); + if (BodyRes.isInvalid()) + return StmtError(); + BodyArgs.Body = BodyRes.get(); + + if (S->getFallthroughHandler()) { + StmtResult Res = getDerived().TransformStmt(S->getFallthroughHandler()); + if (Res.isInvalid()) + return StmtError(); + BodyArgs.OnFallthrough = Res.get(); + } + + if (S->getExceptionHandler()) { + StmtResult Res = getDerived().TransformStmt(S->getExceptionHandler()); + if (Res.isInvalid()) + return StmtError(); + BodyArgs.OnException = Res.get(); + } + + // Transform any additional statements we may have already built + if (S->getAllocate() && S->getDeallocate()) { + ExprResult AllocRes = getDerived().TransformExpr(S->getAllocate()); + if (AllocRes.isInvalid()) + return StmtError(); + BodyArgs.Allocate = AllocRes.get(); + + ExprResult DeallocRes = getDerived().TransformExpr(S->getDeallocate()); + if (DeallocRes.isInvalid()) + return StmtError(); + BodyArgs.Deallocate = DeallocRes.get(); + } + + Expr *ReturnObject = S->getReturnValueInit(); + if (ReturnObject) { + ExprResult Res = getDerived().TransformInitializer(ReturnObject, + /*NoCopyInit*/false); + if (Res.isInvalid()) + return StmtError(); + BodyArgs.ReturnValue = Res.get(); + } + + // Do a partial rebuild of the coroutine body and stash it in the ScopeInfo + return getDerived().RebuildCoroutineBodyStmt(BodyArgs); } template<typename Derived> @@ -6846,7 +6946,8 @@ TreeTransform<Derived>::TransformCoreturnStmt(CoreturnStmt *S) { // Always rebuild; we don't know if this needs to be injected into a new // context or if the promise type has changed. - return getDerived().RebuildCoreturnStmt(S->getKeywordLoc(), Result.get()); + return getDerived().RebuildCoreturnStmt(S->getKeywordLoc(), Result.get(), + S->isImplicit()); } template<typename Derived> @@ -6859,7 +6960,29 @@ TreeTransform<Derived>::TransformCoawaitExpr(CoawaitExpr *E) { // Always rebuild; we don't know if this needs to be injected into a new // context or if the promise type has changed. - return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get()); + return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get(), + E->isImplicit()); +} + +template <typename Derived> +ExprResult +TreeTransform<Derived>::TransformDependentCoawaitExpr(DependentCoawaitExpr *E) { + ExprResult OperandResult = getDerived().TransformInitializer(E->getOperand(), + /*NotCopyInit*/ false); + if (OperandResult.isInvalid()) + return ExprError(); + + ExprResult LookupResult = getDerived().TransformUnresolvedLookupExpr( + E->getOperatorCoawaitLookup()); + + if (LookupResult.isInvalid()) + return ExprError(); + + // Always rebuild; we don't know if this needs to be injected into a new + // context or if the promise type has changed. + return getDerived().RebuildDependentCoawaitExpr( + E->getKeywordLoc(), OperandResult.get(), + cast<UnresolvedLookupExpr>(LookupResult.get())); } template<typename Derived> |