从名字就可以看出来,它的功能是与MySQL数据库连接用的
首先,让我们连接数据库
import MySQLdb
try:
conn = MySQLdb.connect(host='localhost',user='root',passwd='',db='test',port=3306,charset='utf8')
print "Connect Successful !"
conn.close()
except MySQLdb.Error,e:
print "Mysql Error %d: %s" % (e.args[0], e.args[1])
保存为mysqldb_demo.py,运行,看一下结果。
可以看出来,如果MySQL数据库打开且账户密码正确的话就可以正确连接,并显示数据库版本,如果错误则报错并显示错误类型。
接下来,我们试一下数据库的增改删查和刷新。 先来看一下在数据库test中有一个表单test。 test中有三个选项,分别是name,id,sex,数据类型分别是char,int,char。
import MySQLdb
try:
conn = MySQLdb.connect(host='localhost',user='root',passwd='',db='test',port=3306)
print "Connect Successful !"
cur = conn.cursor()
cur.execute("SELECT * FROM test")
data = cur.fetchone()
print data
value = ["Windard",001,"man"]
cur.execute("INSERT INTO test(name,id,sex) VALUES(%s,%s,%s)",value)
#注意一定要有conn.commit()这句来提交,要不然不能真正的插入数据。
conn.commit()
cur.execute("SELECT * FROM test")
data = cur.fetchone()
print data
cur.close()
conn.close()
except MySQLdb.Error,e:
print "Mysql Error %d: %s" % (e.args[0], e.args[1])
保存为mysqldb_first.py,运行,看一下结果。
可以看到之前,在表单里并没有数据,在执行插入了之后有了一行数据。 注意,在执行插入之后一定要commmit()才能实行有效操作,不然不能写入数据库。
在这里注意一下,如果在你的数据中使用了中文的话,需要加入一下四行代码,来确定中文的正常读写。
conn.set_character_set('utf8')
cur.execute('SET NAMES utf8;')
cur.execute('SET CHARACTER SET utf8;')
cur.execute('SET character_set_connection=utf8;')
打开游标也可以使用 上下文管理器 ,这样可以自动执行关闭操作。
# coding=utf-8
import MySQLdb
conn = MySQLdb.Connection(host='127.0.0.1', user='root', passwd='123456', db='test', port=3306, charset='utf8')
with conn as cur:
cur.execute('SELECT * FROM user')
print 'sum:', cur.rowcount
for row in cur.fetchall():
print row
使用 上下文管理器还有一个好处,就是在 with 语句中的所有 SQL 执行作为一个事务 ,只有全部执行成功才会 commit ,否则会自动 rollback
一般的执行 SQL 语句是使用游标 cursor 操作的,我们也可以直接使用 connection 执行 SQL 语句。
# coding=utf-8
import MySQLdb
try:
conn = MySQLdb.connect(host='127.0.0.1', user='root', passwd='123456', db='test', port=3306, charset='utf8')
# 数据库连接信息
print "Host Info :", conn.get_host_info()
print "Server Info:", conn.get_server_info()
conn.query('SELECT * FROM user')
# fetch_row 默认只取一条数据
for row in conn.use_result().fetch_row(10):
print row
except MySQLdb.Error,e:
print "MySQL Error %d: %s" % tuple(e)
else:
conn.close()
再来看一个完整的增改删查的代码。
import MySQLdb
try:
conn = MySQLdb.connect(host='localhost',user='root',passwd='',db='test',port=3306)
print "Connect Successful !"
cur = conn.cursor()
#首先查询原始数据库状态
cur.execute("SELECT * FROM test ")
data = cur.fetchone()
print data
#插入一条数据
value = ["Windard",001,"man"]
cur.execute("INSERT INTO test(name,id,sex) VALUES(%s,%s,%s)",value)
conn.commit()
#查询插入数据库之后的状态
cur.execute("SELECT * FROM test ")
data = cur.fetchone()
print data
#更改数据库数据
cur.execute("UPDATE test SET id = 100 WHERE name = 'Windard'")
#查询更改数据之后的数据库数据
cur.execute("SELECT * FROM test ")
data = cur.fetchone()
print data
#删除数据库数据
cur.execute("DELETE FROM test WHERE name = 'Windard'")
#查询删除数据之后的数据库数据
cur.execute("SELECT * FROM test ")
data = cur.fetchone()
print data
cur.close()
conn.close()
except MySQLdb.Error,e:
print "Mysql Error %d: %s" % (e.args[0], e.args[1])
保存为mysqldb_second.py,运行,看一下结果。
这里包含完整的数据库增改删查的操作。
一般在 Python 中的参数化字符串使用百分号 %
传入,但是使用 SQL 查询的话,参数不用自己来完成,MySQL 的函数库会进行参数化,并自动为其加上单引号。
# -*- coding: utf-8 -*-
import MySQLdb
try:
#打开数据库连接
conn = MySQLdb.connect(host='localhost', user='root', passwd='', db='test')
print "Connect Successful !"
#用cursor()获得操作游标
cur = conn.cursor()
name = ['windard']
# 不要这样使用
cur.execute("SELECT * FROM user WHERE name='%s'" % name[0])
print cur.fetchone()
# 这样才是正确的
cur.execute("SELECT * FROM user WHERE name=%s", name)
print cur.fetchone()
# like 的时候是这样
cur.execute("SELECT * FROM user WHERE name LIKE %s", ['%%%s%%' % name[0]])
print cur.fetchone()
conn.close()
except MySQLdb.Error, e:
print "Mysql Error %d: %s" % (e.args[0], e.args[1])
为了防止 SQL 注入,建议使用 参数化 查询,MySQLdb 自动进行转义过滤。
# coding=utf-8
import MySQLdb
conn = MySQLdb.Connection(host='127.0.0.1', user='root', passwd='123456', db='test', port=3306, charset='utf8')
with conn as cur:
# secure
cur.execute('SELECT * FROM user WHERE id=%s AND name=%s', ["0 or 1=1; # -- ", 'windard'])
print 'sum:', cur.rowcount
for row in cur.fetchall():
print row
# insecure
cur.execute('SELECT * FROM user WHERE id=%s AND name=%s'%("0 or 1=1; # -- ", 'windard'))
print 'sum:', cur.rowcount
for row in cur.fetchall():
print row
MySQLdb 的转义函数是 Connection.literal(o)
,参数可以是字符串或者是列表。
# coding=utf-8
import MySQLdb
conn = MySQLdb.Connection(host='127.0.0.1', user='root', passwd='123456', db='test', port=3306, charset='utf8')
print conn.literal(["0 or 1=1; # -- ", 'windard'])
print conn.literal("0' or 1=1; # -- ")
转义结果为
("'0 or 1=1; # -- '", "'windard'")
'0\' or 1=1; # -- '
那我们试一下创建一个新的数据库和新的表单,插入大量的数据来试试。
import MySQLdb
try:
conn = MySQLdb.connect(host='localhost',user='root',passwd='',port=3306,charset='utf8')
print "Connect Successful !"
cur = conn.cursor()
#创建一个新的数据库名为python
cur.execute("CREATE DATABASE IF NOT EXISTS python")
#连接这个数据库
conn.select_db('python')
#创建一个新的表单test
cur.execute("CREATE TABLE test(id int,info varchar(20))")
#插入单个数据
value = [1,'windard']
cur.execute("INSERT INTO test VALUES(%s,%s)",value)
conn.commit()
#查看结果
cur.execute("SELECT * FROM test ")
data = cur.fetchone()
print data
#插入大量数据
values = []
for i in range(20):
values.append((i,'this is number :' + str(i)))
cur.executemany("INSERT INTO test VALUES(%s,%s)",values)
conn.commit()
#查看结果,此时execute()的返回值是插入数据得到的行数
print "All Database Table"
count = cur.execute("SELECT * FROM test ")
data = cur.fetchmany(count)
for item in data:
print item
#删除表单
cur.execute("DROP TABLE test ")
#删除数据库
cur.execute("DROP DATABASE python")
cur.close()
conn.close()
except MySQLdb.Error,e:
print "Mysql Error %d: %s" % (e.args[0], e.args[1])
保存为mysqldb_third.py,运行,看一下结果。
在这里连接数据库的时候也加上了数据库使用的编码格式,utf8,在使用的时候可以避免乱码的出现。
import MySQLdb
try:
conn = MySQLdb.connect(host='localhost',user='root',passwd='',db='test',port=3306)
print "Connect Successful !"
cur = conn.cursor()
cur.execute("SELECT * FROM test")
data = cur.fetchone()
print data
value = ["Windard",001,"man"]
try:
cur.execute("INSERT INTO test(name,id,sex) VALUES(%s,%s,%s)",value)
#注意一定要有conn.commit()这句来提交,要不然不能真正的插入数据。
conn.commit()
except :
#发生错误时回滚
conn.rollback()
cur.execute("SELECT * FROM test")
data = cur.fetchall()
for item in data:
fname = item[0]
fid = item[1]
fsex = item[2]
print "name = %s ,id = %s , sex = %s " %(fname ,fid ,fsex)
cur.close()
conn.close()
except MySQLdb.Error,e:
print "Mysql Error %d: %s" % (e.args[0], e.args[1])
保存为mysqldb_error.py,运行,看一下结果。
这个代码演示了发生错误时候回滚的操作,rollback()能够把游标指针指到错误发生之前的位置。 还有fetchall()即一次取得全部的数据。 还有其他几个功能类似的函数fetchone(),一次取得一个数据,fetchmany(num),一次取得num个数据。
# -*- coding: utf-8 -*-
import chardet
import MySQLdb
class Database(object):
"""Database Control For Beginner"""
def __init__(self, host='127.0.0.1', user='root', password='', db='',
port=3306, charset='utf8', debug=False):
self.host = host
self.user = user
self.password = password
self.port = port
self.charset = charset
self.db = db
self._debug = debug
try:
self._conn = MySQLdb.Connection(host=self.host, user=self.user,
passwd=self.password,
db=self.db, port=self.port,
charset=self.charset)
self._conn.set_character_set('utf8')
self._cur = self._conn.cursor()
self._cur.execute('SET NAMES utf8;')
self._cur.execute('SET CHARACTER SET utf8;')
self._cur.execute('SET character_set_connection=utf8;')
except Exception, e:
if self._debug:
print tuple(e)
def exec_(self, query, paras=None):
try:
if self._debug:
print query
if paras:
self._cur.execute(query, paras)
else:
self._cur.execute(query)
result = {'code': 1000}
result['content'] = self._cur.fetchall()
return result
except Exception, e:
self._conn.rollback()
result = {'code': 1001, 'content': tuple(e)}
return result
def get(self, table, filed=['*'], options={}):
try:
select, where, order, limit, paras, conds = '', '', '', '', [], []
if '*' in filed:
select = '*'
else:
select = ','.join(filed)
if options.get('where', None):
for key, value in options['where'].items():
if type(value) == str:
value = value.decode(chardet.detect(value)['encoding'])
conds.append("%s %%s" % key)
paras.append(value)
where = 'WHERE '
where += ' AND '.join(conds)
if options.get('order', None):
order = ' ORDER BY ' + options['order']
if options.get('limit', None):
limit = ' LIMIT ' + ','.join(map(str, options['limit']))
return self.exec_("SELECT %s FROM %s %s %s %s"
% (select, table, where, order, limit), paras)
except Exception, e:
return {'code': 1002, 'content': tuple(e)}
def set(self, table, values, options={}):
try:
conds, where, paras, stats = [], '', [], []
for key, value in values.items():
if type(value) == str:
value = value.decode(chardet.detect(value)['encoding'])
conds.append("%s %%s" % key)
paras.append(value)
if options:
for key, value in options.items():
if type(value) == str:
value = value.decode(chardet.detect(value)['encoding'])
stats.append("%s %%s" % key)
paras.append(value)
where = 'WHERE '
where += ' AND '.join(stats)
return self.exec_('UPDATE %s SET %s %s'
% (table, ' AND '.join(conds), where), paras)
except Exception, e:
return {'code': 1003, 'content': tuple(e)}
def new(self, table, values, options=[]):
try:
conds, paras, stats = [], [], ''
for value in values:
if type(value) == str:
value = value.decode(chardet.detect(value)['encoding'])
conds.append("%s")
paras.append(value)
if options:
stats += '(' + ','.join(options) + ')'
return self.exec_('INSERT INTO %s%s VALUES(%s)'
% (table, stats, ','.join(conds)), paras)
except Exception, e:
return {'code': 1004, 'content': tuple(e)}
def del_(self, table, options={}):
try:
where, paras, stats = '', [], []
if options:
for key, value in options.items():
if type(value) == str:
value = value.decode(chardet.detect(value)['encoding'])
stats.append("%s %%s" % key)
paras.append(value)
where = 'WHERE '
where += ' AND '.join(stats)
return self.exec_('DELETE FROM %s %s' % (table, where), paras)
except Exception, e:
return {'code': 1005, 'content': tuple(e)}
def __del__(self):
try:
try:
self._conn.commit()
except Exception, e:
if self._debug:
print tuple(e)
self._conn.rollback()
self._cur.close()
self._conn.close()
except Exception, e:
if self._debug:
print tuple(e)
"""
import databasescontrol
conn = databasescontrol.Database(password="XXXXXX")
print conn.exec_('use test;')
# print conn.get('user')
print conn.get("user", options={'where':{'id >':1, 'name like':'%m%'}, "limit":[1], "order":'id desc'})
# print conn.get("user", options={'where':{'id >':1, 'name =':'mary'}, "limit":[1, 2], "order":'id desc'})
# print conn.get("user", filed=['id', 'name', 'passwd'], options={'where':{'id >':1, 'name =':'姓名'}})
# print conn.get('user', options={'where': {'id in':(1,2)}})
# print conn.set("user", {"name =":"mary"}, {'id >':2, 'name=':'wocao'})
# print conn.new("user", ["name","password", 6], ["name","passwd", 'id'])
# print conn.new('user', [7, 'hello', 'world', 'nihao', 9.0])
# print conn.new('user', [8, '中文', 'world', '测试', 9.0])
# print conn.del_('user', {'name=':'姓名'})
print conn.get('user')
"""
真正的数据库操作模块。。。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
设计db模块的原因:
1. 更简单的操作数据库
一次数据访问: 数据库连接 => 游标对象 => 执行SQL => 处理异常 => 清理资源。
db模块对这些过程进行封装,使得用户仅需关注SQL执行。
2. 数据安全
用户请求以多线程处理时,为了避免多线程下的数据共享引起的数据混乱,
需要将数据连接以 ThreadLocal 对象传入。
设计db接口:
1.设计原则:
根据上层调用者设计简单易用的 API 接口
2. 调用接口
1. 初始化数据库连接信息
create_engine封装了如下功能:
1. 为数据库连接 准备需要的配置信息
2. 创建数据库连接(由生成的全局对象engine的 connect方法提供)
import transwarp as db
db.create_engine(user='root',
password='password',
database='test',
host='127.0.0.1',
port=3306)
2. 执行SQL DML
select 函数封装了如下功能:
1.支持一个数据库连接里执行多个SQL语句
2.支持链接的自动获取和释放
使用样例:
users = db.select('select * from user')
# users =>
# [
# { "id": 1, "name": "Michael"},
# { "id": 2, "name": "Bob"},
# { "id": 3, "name": "Adam"}
# ]
3. 支持事物
transaction 函数封装了如下功能:
1. 事务也可以嵌套,内层事务会自动合并到外层事务中,这种事务模型足够满足99%的需求
"""
import time
import uuid
import functools
import threading
import logging
# global engine object:
engine = None
def next_id(t=None):
"""
生成一个唯一id 由 当前时间 + 随机数(由伪随机数得来)拼接得到
"""
if t is None:
t = time.time()
return '%015d%s000' % (int(t * 1000), uuid.uuid4().hex)
def _profiling(start, sql=''):
"""
用于剖析sql的执行时间
"""
t = time.time() - start
if t > 0.1:
logging.warning('[PROFILING] [DB] %s: %s' % (t, sql))
else:
logging.info('[PROFILING] [DB] %s: %s' % (t, sql))
def create_engine(user, password, database, host='127.0.0.1', port=3306, **kw):
"""
db模型的核心函数,用于连接数据库, 生成全局对象engine,
engine对象持有数据库连接
"""
import mysql.connector
global engine
if engine is not None:
raise DBError('Engine is already initialized.')
params = dict(user=user, password=password, database=database, host=host, port=port)
defaults = dict(use_unicode=True, charset='utf8', collation='utf8_general_ci', autocommit=False)
for k, v in defaults.iteritems():
params[k] = kw.pop(k, v)
params.update(kw)
params['buffered'] = True
engine = _Engine(lambda: mysql.connector.connect(**params))
# test connection...
logging.info('Init mysql engine <%s> ok.' % hex(id(engine)))
def connection():
"""
db模块核心函数,用于获取一个数据库连接
通过_ConnectionCtx对 _db_ctx封装,使得惰性连接可以自动获取和释放,
也就是可以使用 with语法来处理数据库连接
_ConnectionCtx 实现with语法
^
|
_db_ctx _DbCtx实例
^
|
_DbCtx 获取和释放惰性连接
^
|
_LasyConnection 实现惰性连接
"""
return _ConnectionCtx()
def with_connection(func):
"""
设计一个装饰器 替换with语法,让代码更优雅
比如:
@with_connection
def foo(*args, **kw):
f1()
f2()
f3()
"""
@functools.wraps(func)
def _wrapper(*args, **kw):
with _ConnectionCtx():
return func(*args, **kw)
return _wrapper
def transaction():
"""
db模块核心函数 用于实现事物功能
支持事物:
with db.transaction():
db.select('...')
db.update('...')
db.update('...')
支持事物嵌套:
with db.transaction():
transaction1
transaction2
...
"""
return _TransactionCtx()
def with_transaction(func):
"""
设计一个装饰器 替换with语法,让代码更优雅
比如:
@with_transaction
def do_in_transaction():
>>> @with_transaction
... def update_profile(id, name, rollback):
... u = dict(id=id, name=name, email='%s@test.org' % name, passwd=name, last_modified=time.time())
... insert('user', **u)
... update('update user set passwd=? where id=?', name.upper(), id)
... if rollback:
... raise StandardError('will cause rollback...')
>>> update_profile(8080, 'Julia', False)
>>> select_one('select * from user where id=?', 8080).passwd
u'JULIA'
>>> update_profile(9090, 'Robert', True)
Traceback (most recent call last):
...
StandardError: will cause rollback...
"""
@functools.wraps(func)
def _wrapper(*args, **kw):
start = time.time()
with _TransactionCtx():
func(*args, **kw)
_profiling(start)
return _wrapper
@with_connection
def _select(sql, first, *args):
"""
执行SQL,返回一个结果 或者多个结果组成的列表
"""
global _db_ctx
cursor = None
sql = sql.replace('?', '%s')
logging.info('SQL: %s, ARGS: %s' % (sql, args))
try:
cursor = _db_ctx.connection.cursor()
cursor.execute(sql, args)
if cursor.description:
names = [x[0] for x in cursor.description]
if first:
values = cursor.fetchone()
if not values:
return None
return Dict(names, values)
return [Dict(names, x) for x in cursor.fetchall()]
finally:
if cursor:
cursor.close()
def select_one(sql, *args):
"""
执行SQL 仅返回一个结果
如果没有结果 返回None
如果有1个结果,返回一个结果
如果有多个结果,返回第一个结果
>>> u1 = dict(id=100, name='Alice', email='alice@test.org', passwd='ABC-12345', last_modified=time.time())
>>> u2 = dict(id=101, name='Sarah', email='sarah@test.org', passwd='ABC-12345', last_modified=time.time())
>>> insert('user', **u1)
1
>>> insert('user', **u2)
1
>>> u = select_one('select * from user where id=?', 100)
>>> u.name
u'Alice'
>>> select_one('select * from user where email=?', 'abc@email.com')
>>> u2 = select_one('select * from user where passwd=? order by email', 'ABC-12345')
>>> u2.name
u'Alice'
"""
return _select(sql, True, *args)
def select_int(sql, *args):
"""
执行一个sql 返回一个数值,
注意仅一个数值,如果返回多个数值将触发异常
>>> u1 = dict(id=96900, name='Ada', email='ada@test.org', passwd='A-12345', last_modified=time.time())
>>> u2 = dict(id=96901, name='Adam', email='adam@test.org', passwd='A-12345', last_modified=time.time())
>>> insert('user', **u1)
1
>>> insert('user', **u2)
1
>>> select_int('select count(*) from user')
5
>>> select_int('select count(*) from user where email=?', 'ada@test.org')
1
>>> select_int('select count(*) from user where email=?', 'notexist@test.org')
0
>>> select_int('select id from user where email=?', 'ada@test.org')
96900
>>> select_int('select id, name from user where email=?', 'ada@test.org')
Traceback (most recent call last):
...
MultiColumnsError: Expect only one column.
"""
d = _select(sql, True, *args)
if len(d) != 1:
raise MultiColumnsError('Expect only one column.')
return d.values()[0]
def select(sql, *args):
"""
执行sql 以列表形式返回结果
>>> u1 = dict(id=200, name='Wall.E', email='wall.e@test.org', passwd='back-to-earth', last_modified=time.time())
>>> u2 = dict(id=201, name='Eva', email='eva@test.org', passwd='back-to-earth', last_modified=time.time())
>>> insert('user', **u1)
1
>>> insert('user', **u2)
1
>>> L = select('select * from user where id=?', 900900900)
>>> L
[]
>>> L = select('select * from user where id=?', 200)
>>> L[0].email
u'wall.e@test.org'
>>> L = select('select * from user where passwd=? order by id desc', 'back-to-earth')
>>> L[0].name
u'Eva'
>>> L[1].name
u'Wall.E'
"""
return _select(sql, False, *args)
@with_connection
def _update(sql, *args):
"""
执行update 语句,返回update的行数
"""
global _db_ctx
cursor = None
sql = sql.replace('?', '%s')
logging.info('SQL: %s, ARGS: %s' % (sql, args))
try:
cursor = _db_ctx.connection.cursor()
cursor.execute(sql, args)
r = cursor.rowcount
if _db_ctx.transactions == 0:
# no transaction enviroment:
logging.info('auto commit')
_db_ctx.connection.commit()
return r
finally:
if cursor:
cursor.close()
def update(sql, *args):
"""
执行update 语句,返回update的行数
>>> u1 = dict(id=1000, name='Michael', email='michael@test.org', passwd='123456', last_modified=time.time())
>>> insert('user', **u1)
1
>>> u2 = select_one('select * from user where id=?', 1000)
>>> u2.email
u'michael@test.org'
>>> u2.passwd
u'123456'
>>> update('update user set email=?, passwd=? where id=?', 'michael@example.org', '654321', 1000)
1
>>> u3 = select_one('select * from user where id=?', 1000)
>>> u3.email
u'michael@example.org'
>>> u3.passwd
u'654321'
>>> update('update user set passwd=? where id=?', '***', '123')
0
"""
return _update(sql, *args)
def insert(table, **kw):
"""
执行insert语句
>>> u1 = dict(id=2000, name='Bob', email='bob@test.org', passwd='bobobob', last_modified=time.time())
>>> insert('user', **u1)
1
>>> u2 = select_one('select * from user where id=?', 2000)
>>> u2.name
u'Bob'
>>> insert('user', **u2)
Traceback (most recent call last):
...
IntegrityError: 1062 (23000): Duplicate entry '2000' for key 'PRIMARY'
"""
cols, args = zip(*kw.iteritems())
sql = 'insert into `%s` (%s) values (%s)' % (table, ','.join(['`%s`' % col for col in cols]), ','.join(['?' for i in range(len(cols))]))
return _update(sql, *args)
class Dict(dict):
"""
字典对象
实现一个简单的可以通过属性访问的字典,比如 x.key = value
"""
def __init__(self, names=(), values=(), **kw):
super(Dict, self).__init__(**kw)
for k, v in zip(names, values):
self[k] = v
def __getattr__(self, key):
try:
return self[key]
except KeyError:
raise AttributeError(r"'Dict' object has no attribute '%s'" % key)
def __setattr__(self, key, value):
self[key] = value
class DBError(Exception):
pass
class MultiColumnsError(DBError):
pass
class _Engine(object):
"""
数据库引擎对象
用于保存 db模块的核心函数:create_engine 创建出来的数据库连接
"""
def __init__(self, connect):
self._connect = connect
def connect(self):
return self._connect()
class _LasyConnection(object):
"""
惰性连接对象
仅当需要cursor对象时,才连接数据库,获取连接
"""
def __init__(self):
self.connection = None
def cursor(self):
if self.connection is None:
_connection = engine.connect()
logging.info('[CONNECTION] [OPEN] connection <%s>...' % hex(id(_connection)))
self.connection = _connection
return self.connection.cursor()
def commit(self):
self.connection.commit()
def rollback(self):
self.connection.rollback()
def cleanup(self):
if self.connection:
_connection = self.connection
self.connection = None
logging.info('[CONNECTION] [CLOSE] connection <%s>...' % hex(id(connection)))
_connection.close()
class _DbCtx(threading.local):
"""
db模块的核心对象, 数据库连接的上下文对象,负责从数据库获取和释放连接
取得的连接是惰性连接对象,因此只有调用cursor对象时,才会真正获取数据库连接
该对象是一个 Thread local对象,因此绑定在此对象上的数据 仅对本线程可见
"""
def __init__(self):
self.connection = None
self.transactions = 0
def is_init(self):
"""
返回一个布尔值,用于判断 此对象的初始化状态
"""
return self.connection is not None
def init(self):
"""
初始化连接的上下文对象,获得一个惰性连接对象
"""
logging.info('open lazy connection...')
self.connection = _LasyConnection()
self.transactions = 0
def cleanup(self):
"""
清理连接对象,关闭连接
"""
self.connection.cleanup()
self.connection = None
def cursor(self):
"""
获取cursor对象, 真正取得数据库连接
"""
return self.connection.cursor()
# thread-local db context:
_db_ctx = _DbCtx()
class _ConnectionCtx(object):
"""
因为_DbCtx实现了连接的 获取和释放,但是并没有实现连接
的自动获取和释放,_ConnectCtx在 _DbCtx基础上实现了该功能,
因此可以对 _ConnectCtx 使用with 语法,比如:
with connection():
pass
with connection():
pass
"""
def __enter__(self):
"""
获取一个惰性连接对象
"""
global _db_ctx
self.should_cleanup = False
if not _db_ctx.is_init():
_db_ctx.init()
self.should_cleanup = True
return self
def __exit__(self, exctype, excvalue, traceback):
"""
释放连接
"""
global _db_ctx
if self.should_cleanup:
_db_ctx.cleanup()
class _TransactionCtx(object):
"""
事务嵌套比Connection嵌套复杂一点,因为事务嵌套需要计数,
每遇到一层嵌套就+1,离开一层嵌套就-1,最后到0时提交事务
"""
def __enter__(self):
global _db_ctx
self.should_close_conn = False
if not _db_ctx.is_init():
# needs open a connection first:
_db_ctx.init()
self.should_close_conn = True
_db_ctx.transactions += 1
logging.info('begin transaction...' if _db_ctx.transactions == 1 else 'join current transaction...')
return self
def __exit__(self, exctype, excvalue, traceback):
global _db_ctx
_db_ctx.transactions -= 1
try:
if _db_ctx.transactions == 0:
if exctype is None:
self.commit()
else:
self.rollback()
finally:
if self.should_close_conn:
_db_ctx.cleanup()
def commit(self):
global _db_ctx
logging.info('commit transaction...')
try:
_db_ctx.connection.commit()
logging.info('commit ok.')
except:
logging.warning('commit failed. try rollback...')
_db_ctx.connection.rollback()
logging.warning('rollback ok.')
raise
def rollback(self):
global _db_ctx
logging.warning('rollback transaction...')
_db_ctx.connection.rollback()
logging.info('rollback ok.')
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
create_engine('root', 'password', 'test', '127.0.0.1')
update('drop table if exists user')
update('create table user (id int primary key, name text, email text, passwd text, last_modified real)')
import doctest
doctest.testmod()
"""
import time
import transwarp as db
db.create_engine(user='root',
password='password',
database='test',
host='127.0.0.1',
port=3306)
u1 = dict(id=120, name='Wall.E', email='wall.e@test.org', passwd='back-to-earth', last_modified=time.time())
db.insert('user', **u1)
"""