talent-plan-tinykv/kv/coprocessor/cop_handler.go
Connor dd377dedaa Add support of coprocessor (#155)
* add coprocessor support

Signed-off-by: Connor1996 <zbk602423539@gmail.com>

Co-authored-by: Yiding Cui <winoros@gmail.com>
2020-04-30 15:31:26 +08:00

341 lines
10 KiB
Go

// Copyright 2019-present PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package coprocessor
import (
"bytes"
"time"
"github.com/golang/protobuf/proto"
"github.com/juju/errors"
"github.com/pingcap-incubator/tinykv/kv/coprocessor/rowcodec"
"github.com/pingcap-incubator/tinykv/kv/storage"
"github.com/pingcap-incubator/tinykv/kv/transaction/mvcc"
"github.com/pingcap-incubator/tinykv/proto/pkg/coprocessor"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/tablecodec"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tipb/go-tipb"
)
// dummySlice is an empty, non-nil byte slice shared across the package.
// NOTE(review): its users are outside this part of the file; presumably it
// serves as a placeholder value — confirm at the call sites.
var dummySlice = []byte{}
// dagContext carries the per-request state assembled by buildDAG:
//   - reader:    the storage reader handed to the handler;
//   - dagReq:    the DAG request unmarshalled from the request data;
//   - keyRanges: the key ranges taken from the coprocessor request;
//   - evalCtx:   the evaluation context (statement context plus the column
//     layout of the first executor's scan);
//   - startTS:   the start timestamp taken from the coprocessor request.
type dagContext struct {
reader storage.StorageReader
dagReq *tipb.DAGRequest
keyRanges []*coprocessor.KeyRange
evalCtx *evalContext
startTS uint64
}
// CopHandler serves coprocessor DAG requests. It holds no state of its own;
// everything a request needs is passed to its methods.
type CopHandler struct{}
// HandleCopDAGRequest decodes the DAG request carried by req, runs it against
// reader and returns the resulting coprocessor response.
//
// Failures never surface as a Go error: a request that cannot be turned into
// a DAG yields a response with only OtherError set, while executor build or
// execution failures are packed by buildResp together with the warnings
// collected in the statement context and the elapsed time.
func (svr *CopHandler) HandleCopDAGRequest(reader storage.StorageReader, req *coprocessor.Request) *coprocessor.Response {
	begin := time.Now()

	dagCtx, dagReq, err := svr.buildDAG(reader, req)
	if err != nil {
		// No evaluation context exists yet, so report the bare error.
		failed := &coprocessor.Response{}
		failed.OtherError = err.Error()
		return failed
	}

	executor, err := svr.buildClosureExecutor(dagCtx, dagReq)
	if err != nil {
		return buildResp(nil, nil, err, dagCtx.evalCtx.sc.GetWarnings(), time.Since(begin))
	}

	// Any execution error is forwarded to buildResp along with the chunks.
	result, err := executor.execute()
	return buildResp(result, nil, err, dagCtx.evalCtx.sc.GetWarnings(), time.Since(begin))
}
// buildDAG validates a coprocessor request and decodes it into a DAG request
// plus the per-request execution context.
//
// An error (never a panic) is returned when the request has no key ranges, is
// not of type kv.ReqTypeDAG, cannot be unmarshalled, contains no executors, or
// when its first executor lacks the scan description matching its type. A
// first executor whose type is not TypeTableScan is treated as an index scan.
func (svr *CopHandler) buildDAG(reader storage.StorageReader, req *coprocessor.Request) (*dagContext, *tipb.DAGRequest, error) {
	if len(req.Ranges) == 0 {
		return nil, nil, errors.New("request range is null")
	}
	if req.GetTp() != kv.ReqTypeDAG {
		return nil, nil, errors.Errorf("unsupported request type %d", req.GetTp())
	}

	dagReq := new(tipb.DAGRequest)
	if err := proto.Unmarshal(req.Data, dagReq); err != nil {
		return nil, nil, errors.Trace(err)
	}

	// dagReq is decoded from req.Data and may therefore be malformed: check
	// the executor list and the scan description before dereferencing them.
	if len(dagReq.Executors) == 0 || dagReq.Executors[0] == nil {
		return nil, nil, errors.New("dag request has no executors")
	}
	scanExec := dagReq.Executors[0]
	var columns []*tipb.ColumnInfo
	if scanExec.Tp == tipb.ExecType_TypeTableScan {
		if scanExec.TblScan == nil {
			return nil, nil, errors.New("table scan executor has no table scan info")
		}
		columns = scanExec.TblScan.Columns
	} else {
		if scanExec.IdxScan == nil {
			return nil, nil, errors.New("index scan executor has no index scan info")
		}
		columns = scanExec.IdxScan.Columns
	}

	sc := flagsToStatementContext(dagReq.Flags)
	// The zone keeps the name "UTC" but applies the offset from the request.
	sc.TimeZone = time.FixedZone("UTC", int(dagReq.TimeZoneOffset))
	ctx := &dagContext{
		reader:    reader,
		dagReq:    dagReq,
		keyRanges: req.Ranges,
		evalCtx:   &evalContext{sc: sc},
		startTS:   req.StartTs,
	}
	ctx.evalCtx.setColumnInfo(columns)
	return ctx, dagReq, nil
}
// getAggInfo converts the protobuf aggregation description into executable
// form: one aggregation.Aggregation per entry of pbAgg.AggFunc and the
// expressions for the GROUP BY items. Both are built against the field types
// and statement context of ctx.evalCtx. The first conversion error aborts the
// call and is returned traced.
func (svr *CopHandler) getAggInfo(ctx *dagContext, pbAgg *tipb.Aggregation) ([]aggregation.Aggregation, []expression.Expression, error) {
	evalCtx := ctx.evalCtx

	aggFuncs := make([]aggregation.Aggregation, 0, len(pbAgg.AggFunc))
	for _, pbFunc := range pbAgg.AggFunc {
		aggFunc, err := aggregation.NewDistAggFunc(pbFunc, evalCtx.fieldTps, evalCtx.sc)
		if err != nil {
			return nil, nil, errors.Trace(err)
		}
		aggFuncs = append(aggFuncs, aggFunc)
	}

	groupByExprs, err := convertToExprs(evalCtx.sc, evalCtx.fieldTps, pbAgg.GetGroupBy())
	if err != nil {
		return nil, nil, errors.Trace(err)
	}
	return aggFuncs, groupByExprs, nil
}
// getTopNInfo prepares what a TopN executor needs: a topNHeap limited to
// topN.Limit entries and ordered by topN.OrderBy, and the ORDER BY expressions
// converted with the field types and statement context of ctx. When the
// expressions cannot be converted, both results are nil and the error is
// returned traced.
func (svr *CopHandler) getTopNInfo(ctx *evalContext, topN *tipb.TopN) (heap *topNHeap, conds []expression.Expression, err error) {
	// Collect the raw expression of every ORDER BY item.
	orderExprs := make([]*tipb.Expr, 0, len(topN.OrderBy))
	for _, byItem := range topN.OrderBy {
		orderExprs = append(orderExprs, byItem.Expr)
	}

	conds, err = convertToExprs(ctx.sc, ctx.fieldTps, orderExprs)
	if err != nil {
		return nil, nil, errors.Trace(err)
	}

	sorter := topNSorter{
		orderByItems: topN.OrderBy,
		sc:           ctx.sc,
	}
	heap = &topNHeap{
		totalCount: int(topN.Limit),
		topNSorter: sorter,
	}
	return heap, conds, nil
}
// evalContext holds what is needed to decode and evaluate rows of one scan:
//   - colIDs:      maps a column ID to its offset in columnInfos;
//   - columnInfos: the scan's column descriptions, copied by setColumnInfo;
//   - fieldTps:    the field type of each column, parallel to columnInfos;
//   - sc:          the statement context (flags, time zone, warnings).
type evalContext struct {
colIDs map[int64]int
columnInfos []*tipb.ColumnInfo
fieldTps []*types.FieldType
sc *stmtctx.StatementContext
}
// setColumnInfo installs cols as the column layout of the context. It stores
// its own copy of the slice, derives the field type of every column and
// records, for each column ID, the offset of that column.
func (e *evalContext) setColumnInfo(cols []*tipb.ColumnInfo) {
	count := len(cols)
	e.columnInfos = make([]*tipb.ColumnInfo, count)
	e.fieldTps = make([]*types.FieldType, count)
	e.colIDs = make(map[int64]int, count)

	for offset, info := range cols {
		e.columnInfos[offset] = info
		e.fieldTps[offset] = fieldTypeFromPBColumn(info)
		e.colIDs[info.GetColumnId()] = offset
	}
}
// newRowDecoder builds a rowcodec.Decoder covering every column of the
// context, in column order. The ID of the column flagged PkHandle (the last
// one, should several be flagged) is passed as the handle column ID; it stays
// zero when no column is flagged.
func (e *evalContext) newRowDecoder() (*rowcodec.Decoder, error) {
	var (
		total    = len(e.columnInfos)
		ids      = make([]int64, 0, total)
		defaults = make([][]byte, 0, total)
		handleID int64
	)
	for _, info := range e.columnInfos {
		ids = append(ids, info.ColumnId)
		defaults = append(defaults, info.DefaultVal)
		if info.PkHandle {
			handleID = info.ColumnId
		}
	}
	return rowcodec.NewDecoder(ids, handleID, e.fieldTps, defaults, e.sc.TimeZone)
}
// newRowDecoderForOffsets builds a rowcodec.Decoder restricted to the columns
// found at colOffsets (offsets into the context's column list), keeping the
// order of colOffsets. The ID of a selected column flagged PkHandle is passed
// as the handle column ID; it stays zero when none of them is flagged.
func (e *evalContext) newRowDecoderForOffsets(colOffsets []int) (*rowcodec.Decoder, error) {
	total := len(colOffsets)
	ids := make([]int64, 0, total)
	defaults := make([][]byte, 0, total)
	tps := make([]*types.FieldType, 0, total)

	var handleID int64
	for _, offset := range colOffsets {
		col := e.columnInfos[offset]
		ids = append(ids, col.ColumnId)
		defaults = append(defaults, col.DefaultVal)
		tps = append(tps, e.fieldTps[offset])
		if col.PkHandle {
			handleID = col.ColumnId
		}
	}
	return rowcodec.NewDecoder(ids, handleID, tps, defaults, e.sc.TimeZone)
}
// decodeRelatedColumnVals decodes, for every offset in relatedColOffsets, the
// raw bytes value[offset] into row[offset], using the column's field type and
// the time zone of the statement context. Decoding stops at the first failure,
// which is returned traced.
func (e *evalContext) decodeRelatedColumnVals(relatedColOffsets []int, value [][]byte, row []types.Datum) error {
	for _, off := range relatedColOffsets {
		datum, err := tablecodec.DecodeColumnValue(value[off], e.fieldTps[off], e.sc.TimeZone)
		// The slot is written even when decoding failed.
		row[off] = datum
		if err != nil {
			return errors.Trace(err)
		}
	}
	return nil
}
// Bit flags carried in tipb.DAGRequest.Flags that select the execution mode,
// such as how truncation errors are treated.
const (
	// FlagIgnoreTruncate tells the executor to ignore truncation errors.
	// Read-only statements should set it; write statements should not.
	FlagIgnoreTruncate uint64 = 1 << 0
	// FlagTruncateAsWarning tells the executor to report truncation errors as
	// warnings. It only matters when FlagIgnoreTruncate is not set: strict SQL
	// mode returns the truncation as an error, non-strict SQL mode saves it
	// as a warning.
	FlagTruncateAsWarning uint64 = 1 << 1
	// FlagPadCharToFullLength mirrors the sql_mode 'PAD_CHAR_TO_FULL_LENGTH'.
	FlagPadCharToFullLength uint64 = 1 << 2
)
// flagsToStatementContext returns a fresh StatementContext whose
// IgnoreTruncate, TruncateAsWarning and PadCharToFullLength fields reflect the
// corresponding bits of flags (the value of a DAG request's Flags field).
func flagsToStatementContext(flags uint64) *stmtctx.StatementContext {
	isSet := func(bit uint64) bool { return flags&bit != 0 }

	sc := new(stmtctx.StatementContext)
	sc.IgnoreTruncate = isSet(FlagIgnoreTruncate)
	sc.TruncateAsWarning = isSet(FlagTruncateAsWarning)
	sc.PadCharToFullLength = isSet(FlagPadCharToFullLength)
	return sc
}
// buildResp assembles the coprocessor response for one DAG request.
//
// chunks and counts become the payload of a tipb.SelectResponse, err and
// warnings are converted with toPBError, and the marshalled SelectResponse is
// stored in the response's Data field. In addition, a non-nil err is surfaced
// on the response itself: as Locked when its cause is a *mvcc.KeyError that
// carries lock information, otherwise as OtherError. If marshalling fails,
// OtherError is overwritten with the marshalling error and Data stays empty.
//
// dur is accepted for interface compatibility and is currently unused.
func buildResp(chunks []tipb.Chunk, counts []int64, err error, warnings []stmtctx.SQLWarn, dur time.Duration) *coprocessor.Response {
	resp := &coprocessor.Response{}
	selResp := &tipb.SelectResponse{
		Error:        toPBError(err),
		Chunks:       chunks,
		OutputCounts: counts,
	}
	if len(warnings) > 0 {
		selResp.Warnings = make([]*tipb.Error, 0, len(warnings))
		for i := range warnings {
			selResp.Warnings = append(selResp.Warnings, toPBError(warnings[i].Err))
		}
	}
	if err != nil {
		// Errors in this package are routinely wrapped with errors.Trace, so
		// inspect the underlying cause: a plain type assertion on err would
		// miss a wrapped KeyError and lose its lock information.
		if ke, ok := errors.Cause(err).(*mvcc.KeyError); ok && ke.Locked != nil {
			resp.Locked = ke.Locked
		} else {
			resp.OtherError = err.Error()
		}
	}
	data, err := proto.Marshal(selResp)
	if err != nil {
		resp.OtherError = err.Error()
		return resp
	}
	resp.Data = data
	return resp
}
// toPBError converts err into its protobuf form. A nil err yields nil. When
// the cause of err is a *terror.Error, the code and message of its SQL error
// are used; any other error is reported with code 1 and its own message.
func toPBError(err error) *tipb.Error {
	if err == nil {
		return nil
	}
	perr := new(tipb.Error)
	// Look through errors.Trace/Annotate wrapping so that a wrapped
	// terror.Error keeps its SQL error code instead of degrading to code 1.
	switch x := errors.Cause(err).(type) {
	case *terror.Error:
		sqlErr := x.ToSQLError()
		perr.Code = int32(sqlErr.Code)
		perr.Msg = sqlErr.Message
	default:
		perr.Code = int32(1)
		perr.Msg = err.Error()
	}
	return perr
}
// extractKVRanges converts the request's key ranges into kv.KeyRange values,
// preserving their order, or reversing it when descScan is true. A range
// whose start is not strictly smaller than its end stops the conversion: the
// ranges converted so far are returned together with the error.
// reader is not used by the conversion.
func (svr *CopHandler) extractKVRanges(reader storage.StorageReader, keyRanges []*coprocessor.KeyRange, descScan bool) (kvRanges []kv.KeyRange, err error) {
	kvRanges = make([]kv.KeyRange, 0, len(keyRanges))
	for _, pbRange := range keyRanges {
		start, end := pbRange.GetStart(), pbRange.GetEnd()
		if bytes.Compare(end, start) <= 0 {
			return kvRanges, errors.Errorf("invalid range, start should be smaller than end: %v %v", start, end)
		}
		kvRanges = append(kvRanges, kv.KeyRange{
			StartKey: kv.Key(start),
			EndKey:   kv.Key(end),
		})
	}
	if descScan {
		reverseKVRanges(kvRanges)
	}
	return kvRanges, nil
}
// reverseKVRanges reverses the order of kvRanges in place.
func reverseKVRanges(kvRanges []kv.KeyRange) {
	// Walk inwards from both ends, swapping as we go.
	for lo, hi := 0, len(kvRanges)-1; lo < hi; lo, hi = lo+1, hi-1 {
		kvRanges[lo], kvRanges[hi] = kvRanges[hi], kvRanges[lo]
	}
}
// rowsPerChunk is the number of rows appendRow packs into a single chunk.
const rowsPerChunk = 64

// appendRow appends the encoded row data to the last chunk of chunks and
// returns the (possibly grown) slice. rowCnt is the number of rows appended
// before this one; a new chunk is started whenever it is a multiple of
// rowsPerChunk, and also whenever chunks is still empty, so the function never
// indexes into an empty slice.
func appendRow(chunks []tipb.Chunk, data []byte, rowCnt int) []tipb.Chunk {
	if len(chunks) == 0 || rowCnt%rowsPerChunk == 0 {
		chunks = append(chunks, tipb.Chunk{})
	}
	last := &chunks[len(chunks)-1]
	last.RowsData = append(last.RowsData, data...)
	return chunks
}
// maxStartKey returns the greater of the two start keys in byte order.
// regionStartKey is returned when they compare equal.
func maxStartKey(rangeStartKey kv.Key, regionStartKey []byte) []byte {
	start := []byte(rangeStartKey)
	if bytes.Compare(start, regionStartKey) <= 0 {
		return regionStartKey
	}
	return start
}
// minEndKey returns the smaller of the two end keys in byte order, with
// regionEndKey winning ties. An empty regionEndKey never limits the range, so
// rangeEndKey is returned in that case.
// NOTE(review): presumably an empty region end key denotes a region without
// an upper bound — confirm against the region metadata.
func minEndKey(rangeEndKey kv.Key, regionEndKey []byte) []byte {
	end := []byte(rangeEndKey)
	if len(regionEndKey) != 0 && bytes.Compare(end, regionEndKey) >= 0 {
		return regionEndKey
	}
	return end
}
// fieldTypeFromPBColumn translates a protobuf column description into a
// types.FieldType, copying the type, flag, length, decimal and enum elements,
// and looking up the collation name in mysql.Collations by collation ID.
func fieldTypeFromPBColumn(col *tipb.ColumnInfo) *types.FieldType {
	ft := new(types.FieldType)
	ft.Tp = byte(col.GetTp())
	ft.Flag = uint(col.Flag)
	ft.Flen = int(col.GetColumnLen())
	ft.Decimal = int(col.GetDecimal())
	ft.Elems = col.Elems
	// The collation ID is narrowed to uint8 before the lookup.
	ft.Collate = mysql.Collations[uint8(col.GetCollation())]
	return ft
}