增加数据权限的 SQL 重写的上下文

pull/2/head
YunaiV 2021-12-10 10:08:29 +08:00
parent e9ba4ac705
commit eda2b11dad
5 changed files with 191 additions and 30 deletions

View File

@ -29,8 +29,8 @@
<!-- Test 测试相关 --> <!-- Test 测试相关 -->
<dependency> <dependency>
<groupId>org.springframework.boot</groupId> <groupId>cn.iocoder.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId> <artifactId>yudao-spring-boot-starter-test</artifactId>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>

View File

@ -1,11 +1,16 @@
package cn.iocoder.yudao.framework.datapermission.core.interceptor; package cn.iocoder.yudao.framework.datapermission.core.interceptor;
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper; import cn.hutool.core.collection.CollUtil;
import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
import com.alibaba.ttl.TransmittableThreadLocal;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils; import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils; import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool; import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport; import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor; import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import lombok.RequiredArgsConstructor;
import net.sf.jsqlparser.expression.*; import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression; import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
@ -24,33 +29,58 @@ import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds; import org.apache.ibatis.session.RowBounds;
import java.sql.Connection; import java.sql.Connection;
import java.util.Collection; import java.util.*;
import java.util.Deque; import java.util.concurrent.ConcurrentHashMap;
import java.util.LinkedList;
import java.util.List;
@RequiredArgsConstructor
public class DataPermissionInterceptor extends JsqlParserSupport implements InnerInterceptor { public class DataPermissionInterceptor extends JsqlParserSupport implements InnerInterceptor {
// private TenantLineHandler tenantLineHandler; private final DataPermissionRuleFactory ruleFactory;
@Override private final MappedStatementCache mappedStatementCache = new MappedStatementCache();
@Override // SELECT 场景
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) { public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
// TODO 芋艿:这个判断,后续读懂下 // 获得 Mapper 对应的数据权限的规则
if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) return; List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql); if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
// TODO 芋艿null=》DataScope return;
mpBs.sql(parserSingle(mpBs.sql(), null));
} }
@Override PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
try {
// 初始化上下文
ContextHolder.init(rules);
// 处理 SQL
mpBs.sql(parserSingle(mpBs.sql(), null));
} finally {
addMappedStatementCache(ms);
ContextHolder.clear();
}
}
@Override // 只处理 UPDATE / DELETE 场景
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) { public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh); PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
MappedStatement ms = mpSh.mappedStatement(); MappedStatement ms = mpSh.mappedStatement();
SqlCommandType sct = ms.getSqlCommandType(); SqlCommandType sct = ms.getSqlCommandType();
if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) { // 无需处理 Insert 语句 if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) { // 无需处理 Insert 语句
if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) return; // 获得 Mapper 对应的数据权限的规则
List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
return;
}
PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql(); PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
try {
// 初始化上下文
ContextHolder.init(rules);
// 处理 SQL
mpBs.sql(parserMulti(mpBs.sql(), null)); mpBs.sql(parserMulti(mpBs.sql(), null));
} finally {
addMappedStatementCache(ms);
ContextHolder.clear();
}
} }
} }
@ -87,10 +117,6 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
@Override @Override
protected void processUpdate(Update update, int index, String sql, Object obj) { protected void processUpdate(Update update, int index, String sql, Object obj) {
final Table table = update.getTable(); final Table table = update.getTable();
if (ignoreTable(table.getName())) {
// 过滤退出执行
return;
}
update.setWhere(this.andExpression(table, update.getWhere())); update.setWhere(this.andExpression(table, update.getWhere()));
} }
@ -99,10 +125,6 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
*/ */
@Override @Override
protected void processDelete(Delete delete, int index, String sql, Object obj) { protected void processDelete(Delete delete, int index, String sql, Object obj) {
if (ignoreTable(delete.getTable().getName())) {
// 过滤退出执行
return;
}
delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere())); delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
} }
@ -378,4 +400,116 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
return new LongValue(1L); return new LongValue(1L);
} }
/**
* SQL {@link MappedStatementCache}
*
* @param ms MappedStatement
*/
private void addMappedStatementCache(MappedStatement ms) {
if (ContextHolder.getRewrite()) {
return;
}
// 有重写,进行添加
mappedStatementCache.addNoRewritable(ms, ContextHolder.getRules());
}
/**
* SQL 便 {@link DataPermissionRule}
*
* @author
*/
private static final class ContextHolder {
/**
* {@link MappedStatement}
*/
private static final ThreadLocal<List<DataPermissionRule>> RULES = new TransmittableThreadLocal<>();
/**
* SQL
*/
private static final ThreadLocal<Boolean> REWRITE = new TransmittableThreadLocal<>();
public static void init(List<DataPermissionRule> rules) {
RULES.set(rules);
REWRITE.set(false);
}
public static void clear() {
RULES.remove();
REWRITE.remove();
}
public static boolean getRewrite() {
return REWRITE.get();
}
public static void setRewrite(boolean rewrite) {
REWRITE.set(rewrite);
}
public static List<DataPermissionRule> getRules() {
return RULES.get();
}
}
/**
* {@link MappedStatement}
* {@link DataPermissionRule} {@link MappedStatement}
* SQL
*
* @author
*/
private static final class MappedStatementCache {
/**
*
*
* value{@link MappedStatement#getId()}
*/
private final Map<Class<? extends DataPermissionRule>, Set<String>> noRewritableMappedStatements = new ConcurrentHashMap<>();
/**
*
* ps
*
* @param ms MappedStatement
* @param rules
* @return
*/
public boolean noRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
// 如果规则为空,说明无需重写
if (CollUtil.isEmpty(rules)) {
return true;
}
// 任一规则不在 noRewritableMap 中,则说明可能需要重写
for (DataPermissionRule rule : rules) {
Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
if (!CollUtil.contains(mappedStatementIds, ms.getId())) { // 不存在,则说明可能要重写
return false;
}
}
return true;
}
/**
* MappedStatement
*
* @param ms MappedStatement
* @param rules
*/
public void addNoRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
for (DataPermissionRule rule : rules) {
Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
if (CollUtil.isNotEmpty(mappedStatementIds)) {
mappedStatementIds.add(ms.getId());
} else {
noRewritableMappedStatements.put(rule.getClass(), SetUtils.asSet(ms.getId()));
}
}
}
}
} }

View File

@ -1,9 +1,28 @@
package cn.iocoder.yudao.framework.datapermission.core.rule; package cn.iocoder.yudao.framework.datapermission.core.rule;
import java.util.List;
/** /**
* {@link DataPermissionRule} * {@link DataPermissionRule}
* 1. {@link DataPermissionRule} * {@link DataPermissionRule}
* 2. TODO *
* @author
*/ */
public interface DataPermissionRuleFactory { public interface DataPermissionRuleFactory {
/**
*
*
* @return
*/
List<DataPermissionRule> getDataPermissionRules();
/**
* Mapper
*
* @param mappedStatementId Mapper
* @return
*/
List<DataPermissionRule> getDataPermissionRule(String mappedStatementId);
} }

View File

@ -1,12 +1,20 @@
package cn.iocoder.yudao.framework.datapermission.core.interceptor; package cn.iocoder.yudao.framework.datapermission.core.interceptor;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
public class DataPermissionInterceptorTest { public class DataPermissionInterceptorTest extends BaseMockitoUnitTest {
private final DataPermissionInterceptor interceptor = new DataPermissionInterceptor(); @InjectMocks
private DataPermissionInterceptor interceptor;
@Mock
private DataPermissionRuleFactory ruleFactory;
@Test @Test
public void selectSingle() { public void selectSingle() {

View File

@ -15,7 +15,7 @@ public abstract class AbstractChannelMessage extends AbstractRedisMessage {
* *
* @return Channel * @return Channel
*/ */
@JsonIgnore // 避免序列化 @JsonIgnore // 避免序列化。原因是Redis 发布 Channel 消息的时候,已经会指定。
public abstract String getChannel(); public abstract String getChannel();
} }